From 9f034d110544ca608a1077e00b0dd0f1b96375e8 Mon Sep 17 00:00:00 2001 From: lizhao Date: Tue, 9 Jun 2026 14:23:25 +0800 Subject: [PATCH] update --- Decoder.py | 127 ++++++++++++++++++++++------------------------- README.md | 6 ++- Zmq/zmqServer.py | 16 +++++- config.ini | 1 + datamock.py | 2 +- 5 files changed, 82 insertions(+), 70 deletions(-) diff --git a/Decoder.py b/Decoder.py index c84977c..a3ee1cd 100644 --- a/Decoder.py +++ b/Decoder.py @@ -26,6 +26,8 @@ from SSVEP.dwfbcca import FbccaDw from collections import deque from Zmq.filterProcess import SlidingFilter +save_train_data = int(IniRead('system', 'save_train_data', 0)) + def get_root_path(): """ Nuitka 打包专用:获取程序根目录(.py 或 .exe 所在目录) @@ -209,11 +211,10 @@ class Decoder_main(threading.Thread): elif self.decoder_class == 'mi': self.decoder_MI() else: - if self.zmqServer.open_Impedance == False: # 非阻抗检测状态 - if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25: - time.sleep(0.005) - continue; - self.zmqServer.paradigmBuffer.getData(25) + if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25: + time.sleep(0.005) + continue; + self.zmqServer.paradigmBuffer.getData(25) except Exception as e: algo_log(f"Decoder Loop Error: {e}") time.sleep(0.1) # Prevent CPU spin if error is persistent @@ -223,31 +224,32 @@ class Decoder_main(threading.Thread): self.zmqServer.StartDecode = False self.decodingSteps = 1 self.zmqServer.paradigmBuffer.resetAllPara() - print('启动预测') + algo_log('启动SSVEP预测', level="DEBUG") if self.zmqServer.paradigmBuffer.GetDataLenCount() < 50: time.sleep(0.005) return if self.zmqServer.open_Impedance: # 阻抗检测状态不解码 return data = self.zmqServer.paradigmBuffer.getDataViaSSVEP(50) + algo_log(f"SSVEP取出的:{data.shape}, data = {data[:20]}", level="DEBUG") data = data[:self.n_chan, :] if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热 self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时 self.dw.warmFilter(data) # 预热 self.decodingSteps = 2 - print('预热数据完成。开始预测') + algo_log('SSVEP预热数据完成。开始预测', level="DEBUG") return if self.decodingSteps == 2 and hasattr(self,'dw'): # 解码中 choosenNum = self.dw.fbccaDWMW(data, self.referenceData, self.DW_cost_tv, self.calculateCount) self.calculateCount += 1 if choosenNum != -1 and self.is_valid_signal(data): self.decodingSteps = 3 - print('预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount)) + algo_log('SSVEP预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount), level="DEBUG") self.calculateCount = 0 if self.decodingSteps == 3: # 发送解码后的信息 self.zmqServer.broadcast_message('result', int(choosenNum)) self.decodingSteps = 0 - print('发送给界面完成。') + algo_log('SSVEP发送给界面完成。', level="DEBUG") def decoder_SSMVEP(self): '''模型训练''' @@ -255,34 +257,28 @@ class Decoder_main(threading.Thread): self.trainLabel.count(i) >= self.single_train for i in range(len(self.list_freqs))): # 模型尚未训练完成 self.trainData = np.array(self.trainData) self.trainLabel = np.array(self.trainLabel) - print(np.shape(self.trainData), (self.trainLabel)) - # 保存多个数组到文件 - # np.savez('20250520_yy.npz', array1=self.trainData, array2=self.trainLabel) - # self.decoder = self.fbtdca.fit(self.trainData, self.trainLabel, Yf=self.Yf) + algo_log(f"开始SSMVEP模型训练,数据形状:{np.shape(self.trainData)},标签形状:{self.trainLabel.shape}", level="DEBUG") + if save_train_data == 1: + now_str = datetime.now().strftime("%Y%m%d_%H%M%S") + save_path = f"{now_str}.npz" + np.savez(save_path, array1=self.trainData, array2=self.trainLabel) self.decoder = self.tdca.fit(self.trainData, self.trainLabel, Yf=self.Yf) now = datetime.now() formatted_time = now.strftime('%H:%M:%S.%f')[:-3] - print('模型训练完成', formatted_time) + algo_log(f"SSMVEP模型训练完成,时间:{formatted_time}", level="DEBUG") self.load_model = True self.zmqServer.broadcast_message('paradigm', 1) '''训练阶段采集数据''' if self.zmqServer.state_mode == 'train': # 训练状态 - - if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \ self.train_epoch[1] + self.zmqServer.event_inner_idx: - self.currentLabel = self.zmqServer.currentLabel - - print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount()) trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据 - - print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.zmqServer.event_inner_idx]) + algo_log(f"取出的:{trainTrial.shape},event:{trainTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG") trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理 trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[ 0]:self.zmqServer.event_inner_idx + self.train_epoch[1]] - print('trial: ', self.zmqServer.event_inner_idx, self.train_epoch[0], self.train_epoch[1]) if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance( self.trainLabel, list) \ and self.trainLabel.count(self.currentLabel) < self.single_train: @@ -301,15 +297,14 @@ class Decoder_main(threading.Thread): self.zmqServer.StartDecode = False now = datetime.now() formatted_time = now.strftime('%H:%M:%S.%f')[:-3] - print('启动预测 ', formatted_time) - + algo_log(f"SSMVEP模型启动预测 {formatted_time}", level="DEBUG") if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \ self.interval_epoch[1] \ + self.zmqServer.event_inner_idx: time.sleep(0.0001) return data = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 读取全部数据 - print('取出的: ', data.shape, 'event: ', data[-2, self.zmqServer.event_inner_idx]) + algo_log(f"取出的:{data.shape}, event: {data[-2, self.zmqServer.event_inner_idx]}", level="DEBUG") data = self.preprocess(data[:self.n_chan, :]) # 预处理 data = data[:, self.zmqServer.event_inner_idx + self.interval_epoch[ @@ -320,16 +315,14 @@ class Decoder_main(threading.Thread): choosenNum, features_2 = self.decoder.predict(pad_eeg_test) if isinstance(choosenNum, np.ndarray): choosenNum = choosenNum[0] - print('结果:', choosenNum, 'rho: ', sorted(features_2[0]), - sorted(features_2[0])[-1] - sorted(features_2[0])[-2]) + algo_log(f"结果:{choosenNum}, rho: {sorted(features_2[0])[-1] - sorted(features_2[0])[-2]}", level="DEBUG") self.zmqServer.broadcast_message('result', int(choosenNum)) - print('发送给界面完成。') + algo_log("SSMVEP发送给界面完成。", level="DEBUG") else: # 休息状态 - if self.zmqServer.open_Impedance == False: # 非阻抗检测状态 - if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25: - time.sleep(0.005) - return - self.zmqServer.paradigmBuffer.getData(25) + if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25: + time.sleep(0.005) + return + self.zmqServer.paradigmBuffer.getData(25) def decoder_MI(self): '''模型训练''' @@ -339,7 +332,11 @@ class Decoder_main(threading.Thread): self.train_started = True self.trainData = np.array(self.trainData) self.trainLabel = np.array(self.trainLabel) + 1 - # print('训练集:',np.shape(self.trainData), (self.trainLabel)) + algo_log(f"MI开始训练,训练集:{np.shape(self.trainData)},标签shape:{np.shape(self.trainLabel)}", level="DEBUG") + if save_train_data == 1: + now_str = datetime.now().strftime("%Y%m%d_%H%M%S") + save_path = f"{now_str}.npz" + np.savez(save_path, array1=self.trainData, array2=self.trainLabel) p = mp.Process(target=onlineTrain, args=(self.mp_data_queue, self.mp_result_queue)) # 开启子进程,训练模型 p.start() self.mp_data_queue.put({'data': self.trainData, 'label': self.trainLabel, 'modelPath': self.modelPath, @@ -350,7 +347,7 @@ class Decoder_main(threading.Thread): try: result = self.mp_result_queue.get_nowait() if result['status'] == 'success': - print("模型训练完成,加载新模型") + algo_log("MI模型训练完成,加载新模型", level="DEBUG") # 调用模型 self.model = torch.load(self.modelPath, weights_only=False) self.model.eval() @@ -363,45 +360,42 @@ class Decoder_main(threading.Thread): self.load_model = True self.zmqServer.broadcast_message('paradigm', 1) # 模型调用完毕,通知上位机 else: - print("训练失败:", result['msg']) + algo_log("MI训练失败: " + result['msg'], level="DEBUG") except Empty: pass # 还没完成 except Exception as e: - print('模型调用失败: ', e) + algo_log("MI模型训练失败: " + str(e), level="DEBUG") '''训练阶段采集数据''' if self.zmqServer.state_mode == 'train' and self.train_started == False: # 训练状态 - if self.zmqServer.StartTrain: - self.currentLabel = self.zmqServer.currentLabel - self.zmqServer.StartTrain = False - if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \ - self.interval_epoch[1] \ - + self.zmqServer.event_inner_idx: + if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \ + self.interval_epoch[1] + self.zmqServer.event_inner_idx: + algo_log(f"训练队列数据:{self.zmqServer.paradigmBuffer.GetDataLenCount()}", level="DEBUG") + originalTrial = self.zmqServer.paradigmBuffer.get_MIData() # 取出MI导联数据 + algo_log(f"取出的:{originalTrial.shape},event: {originalTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG") + trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理 + trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.interval_epoch[ + 0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]] + algo_log(f"trial: {self.zmqServer.event_inner_idx},{self.interval_epoch[0]},{self.interval_epoch[1]}", level="DEBUG") + if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel, + list) \ + and self.trainLabel.count(self.currentLabel) < self.single_train: + self.trainData.append(trainTrial) + self.trainLabel.append(self.currentLabel) + algo_log(f"训练集:{np.shape(self.trainData)}", level="DEBUG") + self.plotData.append(originalTrial[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[ + 0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]) + self.plotLabel.append(self.currentLabel) + else: time.sleep(0.0001) return - print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount()) - originalTrial = self.zmqServer.paradigmBuffer.get_MIData() # 取出MI导联数据 - print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, self.zmqServer.event_inner_idx]) - trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理 - trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.interval_epoch[ - 0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]] - print('trial: ', self.zmqServer.event_inner_idx, self.interval_epoch[0], self.interval_epoch[1]) - if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel, - list) \ - and self.trainLabel.count(self.currentLabel) < self.single_train: - self.trainData.append(trainTrial) - self.trainLabel.append(self.currentLabel) - print('训练集:', np.shape(self.trainData)) - self.plotData.append(originalTrial[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[ - 0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]) - self.plotLabel.append(self.currentLabel) elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态 if self.zmqServer.StartDecode: self.zmqServer.StartDecode = False now = datetime.now() formatted_time = now.strftime('%H:%M:%S.%f')[:-3] - print('启动预测 ', formatted_time) + algo_log(f"MI启动预测 {formatted_time}", level="DEBUG") if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \ self.interval_epoch[1] \ @@ -409,7 +403,7 @@ class Decoder_main(threading.Thread): time.sleep(0.0001) return originalData = self.zmqServer.paradigmBuffer.get_MIData() # 读取全部数据 - print('取出的: ', originalData.shape, 'event: ', originalData[-2, self.zmqServer.event_inner_idx]) + algo_log(f"取出的:{originalData.shape},event: {originalData[-2, self.zmqServer.event_inner_idx]}", level="DEBUG") start = time.time() data = self.preprocess(originalData[:self.n_chan, :]) # 预处理 data = data[:, @@ -426,16 +420,15 @@ class Decoder_main(threading.Thread): Cls = self.model(test_data) y_pred = torch.max(Cls, 1)[1] self.plotLabel.append(int(y_pred.item())) - print('运动意图识别: ', y_pred) + algo_log(f"MI运动意图识别: {y_pred}") self.zmqServer.broadcast_message('paradigm', int(y_pred.item())) end = time.time() print(f'发送给界面完成,耗时{end - start:.3f}s。') else: # 休息状态 - if self.zmqServer.open_Impedance == False: # 非阻抗检测状态 - if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25: - time.sleep(0.005) - return - self.zmqServer.paradigmBuffer.getData(25) + if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25: + time.sleep(0.005) + return + self.zmqServer.paradigmBuffer.getData(25) # def decoder_concentration(self): # if self.zmqServer.state_mode == 'predict': diff --git a/README.md b/README.md index ef177f0..4e80830 100644 --- a/README.md +++ b/README.md @@ -19,4 +19,8 @@ source activate 3in1Py310 python runDecoder.py python datamock.py python ZeroMQClient_mock.py -python system_test.py \ No newline at end of file +python system_test.py + + +# 遗留问题 +1. mvep是否要把list freq 开放到config \ No newline at end of file diff --git a/Zmq/zmqServer.py b/Zmq/zmqServer.py index 19453ef..2cea875 100644 --- a/Zmq/zmqServer.py +++ b/Zmq/zmqServer.py @@ -250,6 +250,20 @@ class zmqServer(threading.Thread): self.decoder_switch = True elif method == "train": self.state_mode = 'train' + resp = { + "method": "train_response", + "params": { + "code": 200, + "message": "ok" + } + } + try: + resp_bytes = json.dumps(resp, ensure_ascii=False).encode("utf-8") + self.cmd_socket.send_multipart([ident, b"", resp_bytes]) + algo_log(f"train 命令已即时回复客户端 {ident}", level="DEBUG") + except Exception as e: + algo_log(f"train 命令回复失败: {e}", level="ERROR") + return elif method == "predict": self.state_mode = 'predict' if params == 1: #开始解码 @@ -360,7 +374,7 @@ class zmqServer(threading.Thread): # -------------------------- 主循环 -------------------------- def run(self): self.running = True - algo_log(f"ZMQ服务器启动成功 - 命令端口: {self.cmd_port}, 数据端口: {self.data_port}", level="INFO") + algo_log(f"ZMQ服务器启动成功 - host: {self.host}, 命令端口: {self.cmd_port}, 数据端口: {self.data_port}", level="INFO") try: while self.running: diff --git a/config.ini b/config.ini index 370c50e..1eb7bd5 100644 --- a/config.ini +++ b/config.ini @@ -18,6 +18,7 @@ Upper_Port = 8088 Serial_port = COM44 algo_log_level = DEBUG console_output = 1 +save_train_data = 0 ; 64 导设备配置 [device_type_1] diff --git a/datamock.py b/datamock.py index 54e26ec..2c7f3e2 100644 --- a/datamock.py +++ b/datamock.py @@ -12,7 +12,7 @@ EEG_FREQ = 10 # EEG 正弦波频率 Hz EEG_AMP = 100.0 # EEG 幅值 100μV LABEL_INTERVAL = 5 # 标签间隔秒数 # SERVER_ADDR = 'tcp://127.0.0.1:8100' -SERVER_ADDR = 'tcp://127.0.0.1:8100' +SERVER_ADDR = 'tcp://10.200.27.140:8100' # 发送间隔: 每包 5 采样点 / 250Hz = 20ms PKT_INTERVAL = N_SAMPLES_PER_PKT / FS