Compare commits

...

41 Commits

Author SHA1 Message Date
c27e250fad update log config 2026-06-13 19:47:27 +08:00
66c0b71b89 update ip 2026-06-13 17:35:46 +08:00
5c7b73b7a4 add log 2026-06-13 16:49:29 +08:00
9690971f43 update 2026-06-13 11:54:58 +08:00
5a5f103ef6 update log 2026-06-13 10:06:29 +08:00
b31bb18dfe update 2026-06-12 15:21:47 +08:00
38480a2ca3 remove release 2026-06-12 14:30:11 +08:00
62e7cab5be add sleep if open impedence 2026-06-12 13:56:48 +08:00
Ivey Song
b26ae2ce3c beta psd 独立线程 2026-06-12 11:33:48 +08:00
5488626112 update 2026-06-11 14:29:43 +08:00
d59b0f695f realeas v1 2026-06-11 11:55:35 +08:00
0570d41439 bug fix 2026-06-11 11:06:59 +08:00
4574798d86 release v1 2026-06-11 09:21:57 +08:00
d480107b37 update 2026-06-11 08:11:29 +08:00
2d70fc9956 update log path 2026-06-11 08:04:08 +08:00
Ivey Song
1bbe84eb56 beta psd return 2026-06-10 17:55:43 +08:00
Ivey Song
f21367bc20 betapsd 回调 2026-06-10 17:55:43 +08:00
ba4ae92647 replace print with algo_log 2026-06-10 16:04:02 +08:00
43adc6fb42 update Decoder 2026-06-10 15:18:22 +08:00
b329989181 update 2026-06-10 11:22:40 +08:00
68106d8aed nuitka package test 2026-06-10 10:57:30 +08:00
Ivey Song
506ebfd973 MI trainLabel revise 2026-06-10 10:05:08 +08:00
Ivey Song
5a2cc82100 update 2026-06-10 09:28:24 +08:00
Ivey Song
81a8d78ab2 upper mock 2026-06-10 09:25:11 +08:00
73e01782df update ssvep test case 2026-06-10 08:24:20 +08:00
b78e583bec update log 2026-06-10 07:55:34 +08:00
504e89ee47 update 2026-06-10 07:48:43 +08:00
Ivey Song
a9dbe7261b update 2026-06-09 19:30:27 +08:00
7b5f4f6eb9 update zmq log 2026-06-09 19:11:21 +08:00
0cffd1ae02 update filter parameter 2026-06-09 19:10:54 +08:00
0e5e79fcdd update filter 2026-06-09 18:30:56 +08:00
694321b52c add filter test case 2026-06-09 16:46:07 +08:00
9f034d1105 update 2026-06-09 14:23:25 +08:00
07560304ca del train 2026-06-09 10:57:28 +08:00
f47e7d914f update log 2026-06-08 19:43:44 +08:00
af4fb48737 update 2026-06-08 17:29:27 +08:00
fdddc814c7 fitler buffer with lock 2026-06-08 17:13:25 +08:00
d741e3548f buffer v1 2026-06-08 17:06:27 +08:00
509fc5a1d7 update 2026-06-08 16:07:09 +08:00
Ivey Song
67587f354b Merge branch 'master' of http://47.98.56.110:7001/lizhao/bci_algo 2026-06-08 15:59:02 +08:00
Ivey Song
d5ef2311a1 数据帧标准3帧,新增recv 接收数据 2026-06-08 15:58:42 +08:00
23 changed files with 2461 additions and 559 deletions

11
.gitignore vendored
View File

@@ -2,10 +2,14 @@
__pycache__/ __pycache__/
# Distribution / packaging # Distribution / packaging
release/
build/ build/
dist/ dist/
dist_nuitka/
# Environments upperHost_stim/
.vscode/
#!upperHost_stim/MI_headless.py
#!upperHost_stim/ssmvep_headless.py
.env .env
.venv .venv
env/ env/
@@ -24,7 +28,8 @@ venv.bak/
*.xlsx *.xlsx
*.mat *.mat
*.json *.json
*.txt
*.pth
# PyCharm # PyCharm
# JetBrains specific template is maintained in a separate repository that is not distributed with PyCharm itself # JetBrains specific template is maintained in a separate repository that is not distributed with PyCharm itself

View File

@@ -14,8 +14,8 @@ from torch.autograd import Variable
# from Device.SunnyLinker import SunnyLinker64 # from Device.SunnyLinker import SunnyLinker64
from SSMVEP.algorithm.tdca import TDCA from SSMVEP.algorithm.tdca import TDCA
from SSMVEP.algorithm.base import generate_cca_references from SSMVEP.algorithm.base import generate_cca_references
from concentration.algorithm.calculate_focus import Calculate # from concentration.algorithm.calculate_focus import Calculate
from blinkdetection.algorithm.eye_detection import blink_detection # from blinkdetection.algorithm.eye_detection import blink_detection
from Zmq.zmqServer import zmqServer from Zmq.zmqServer import zmqServer
from Zmq.zmqClient import zmqClient from Zmq.zmqClient import zmqClient
from MI.Algorithm.conformer_2class import onlineTrain from MI.Algorithm.conformer_2class import onlineTrain
@@ -26,6 +26,8 @@ from SSVEP.dwfbcca import FbccaDw
from collections import deque from collections import deque
from Zmq.filterProcess import SlidingFilter from Zmq.filterProcess import SlidingFilter
save_train_data = int(IniRead('system', 'save_train_data', 0))
def get_root_path(): def get_root_path():
""" """
Nuitka 打包专用:获取程序根目录(.py 或 .exe 所在目录) Nuitka 打包专用:获取程序根目录(.py 或 .exe 所在目录)
@@ -60,6 +62,8 @@ class Decoder_main(threading.Thread):
# 注册滤波结果回调(示例:打印数据形状) # 注册滤波结果回调(示例:打印数据形状)
self.sliding_filter.filter_result_callback = self.zmqServer.send_filtered_data self.sliding_filter.filter_result_callback = self.zmqServer.send_filtered_data
# 注册 beta_psd 广播回调,每秒通过 8099 端口发送给上位机
self.sliding_filter.set_beta_broadcast_callback(lambda v: self.zmqServer.broadcast_message('beta_psd', v))
def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号 def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号
# data: (chans, samples) # data: (chans, samples)
@@ -74,7 +78,7 @@ class Decoder_main(threading.Thread):
:return: :return:
''' '''
self.decoder_class = decoder_class self.decoder_class = decoder_class
if decoder_class == 'ssvep' or decoder_class == 'pvs': if self.decoder_class == 'ssvep' or self.decoder_class == 'pvs':
self.n_chan = 8 self.n_chan = 8
# self.thread_data_server.interval_inited = False # self.thread_data_server.interval_inited = False
DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue')) DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue'))
@@ -96,7 +100,7 @@ class Decoder_main(threading.Thread):
elif decoder_class == 'ssmvep': elif decoder_class == 'ssmvep':
self.zmqServer.interval_init(decoder_class) self.zmqServer.interval_init(decoder_class)
self.n_chan = 8 self.n_chan = 8
self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) # [0.2, 2.2]
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位 self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
self.single_train = 10 # 单类别数量 self.single_train = 10 # 单类别数量
self.num_target = 2 # 分类目标数目 self.num_target = 2 # 分类目标数目
@@ -110,8 +114,8 @@ class Decoder_main(threading.Thread):
elif decoder_class == 'mi' or decoder_class == 'ma': elif decoder_class == 'mi' or decoder_class == 'ma':
self.zmqServer.interval_init(decoder_class) self.zmqServer.interval_init(decoder_class)
self.n_chan = 21 self.n_chan = 21
self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) # [0.5, 4.5]
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位 self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度4s,# 精确到小数点后6位
self.single_train = 40 # 单类别数量 self.single_train = 40 # 单类别数量
self.num_target = 2 # 分类目标数目 self.num_target = 2 # 分类目标数目
@@ -153,8 +157,8 @@ class Decoder_main(threading.Thread):
# self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.device_info['sample_rate'] / 2), self.h_freq / (self.device_info['sample_rate'] / 2)], btype='band') # self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.device_info['sample_rate'] / 2), self.h_freq / (self.device_info['sample_rate'] / 2)], btype='band')
def parameter_init(self,bandPass_low,bandPass_high): def parameter_init(self,bandPass_low,bandPass_high):
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in self.interval_epoch] # epoch截取信息 self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in self.interval_epoch] # epoch截取信息 ssmvep [50, 550]
self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch ssmevep [50, 575]
self.trainData = [] #训练数据 self.trainData = [] #训练数据
self.trainLabel = [] #训练标签 self.trainLabel = [] #训练标签
self.plotData = [] #报告分析数据 self.plotData = [] #报告分析数据
@@ -185,7 +189,7 @@ class Decoder_main(threading.Thread):
def run(self): def run(self):
while self.Runing: while self.Runing:
# 当滤波数据大于5秒时启动滤波线程 # 当滤波数据大于5秒时启动滤波线程
if self.zmqServer.filterBuffer.GetDataLenCount() > self.device_info['sample_rate'] * 5: if not self.sliding_filter.is_alive() and self.zmqServer.filterBuffer.GetDataLenCount() > self.device_info['sample_rate'] * 5:
algo_log("启动滤波线程", level="DEBUG") algo_log("启动滤波线程", level="DEBUG")
self.sliding_filter.start() self.sliding_filter.start()
@@ -202,6 +206,9 @@ class Decoder_main(threading.Thread):
self.zmqServer.state_mode = 'rest' self.zmqServer.state_mode = 'rest'
try: try:
if self.zmqServer.open_Impedance:
time.sleep(0.005)
continue
if self.decoder_class == 'ssvep' or self.decoder_class == 'pvs': if self.decoder_class == 'ssvep' or self.decoder_class == 'pvs':
self.decoder_SSVEP() self.decoder_SSVEP()
elif self.decoder_class == 'ssmvep': elif self.decoder_class == 'ssmvep':
@@ -209,11 +216,10 @@ class Decoder_main(threading.Thread):
elif self.decoder_class == 'mi': elif self.decoder_class == 'mi':
self.decoder_MI() self.decoder_MI()
else: else:
if self.zmqServer.open_Impedance == False: # 非阻抗检测状态 if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25: time.sleep(0.005)
time.sleep(0.005) continue
continue; self.zmqServer.paradigmBuffer.getData(25)
self.zmqServer.paradigmBuffer.getData(25)
except Exception as e: except Exception as e:
algo_log(f"Decoder Loop Error: {e}") algo_log(f"Decoder Loop Error: {e}")
time.sleep(0.1) # Prevent CPU spin if error is persistent time.sleep(0.1) # Prevent CPU spin if error is persistent
@@ -223,71 +229,70 @@ class Decoder_main(threading.Thread):
self.zmqServer.StartDecode = False self.zmqServer.StartDecode = False
self.decodingSteps = 1 self.decodingSteps = 1
self.zmqServer.paradigmBuffer.resetAllPara() self.zmqServer.paradigmBuffer.resetAllPara()
print('启动预测') algo_log('启动SSVEP预测', level="DEBUG")
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 50: if self.zmqServer.paradigmBuffer.GetDataLenCount() < 50:
time.sleep(0.005) time.sleep(0.005)
return return
if self.zmqServer.open_Impedance: # 阻抗检测状态不解码 if self.zmqServer.open_Impedance: # 阻抗检测状态不解码
return return
data = self.zmqServer.paradigmBuffer.getDataViaSSVEP(50) data = self.zmqServer.paradigmBuffer.getDataViaSSVEP(50)
# algo_log(f"SSVEP取出的{data.shape}, data = {data[:, :10]}", level="DEBUG")
data = data[:self.n_chan, :] data = data[:self.n_chan, :]
if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热 if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时 self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
self.dw.warmFilter(data) # 预热 self.dw.warmFilter(data) # 预热
self.decodingSteps = 2 self.decodingSteps = 2
print('预热数据完成。开始预测') algo_log('SSVEP预热数据完成。开始预测', level="DEBUG")
return return
if self.decodingSteps == 2 and hasattr(self,'dw'): # 解码中 if self.decodingSteps == 2 and hasattr(self,'dw'): # 解码中
choosenNum = self.dw.fbccaDWMW(data, self.referenceData, self.DW_cost_tv, self.calculateCount) choosenNum = self.dw.fbccaDWMW(data, self.referenceData, self.DW_cost_tv, self.calculateCount)
self.calculateCount += 1 self.calculateCount += 1
if choosenNum != -1 and self.is_valid_signal(data): if choosenNum != -1 and self.is_valid_signal(data):
self.decodingSteps = 3 self.decodingSteps = 3
print('预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount)) algo_log('SSVEP预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount), level="DEBUG")
self.calculateCount = 0 self.calculateCount = 0
if self.decodingSteps == 3: # 发送解码后的信息 if self.decodingSteps == 3: # 发送解码后的信息
self.zmqServer.broadcast_message('result', int(choosenNum)) self.zmqServer.broadcast_message('result', int(choosenNum))
self.decodingSteps = 0 self.decodingSteps = 0
print('发送给界面完成。') algo_log('SSVEP发送给界面完成。', level="DEBUG")
def decoder_SSMVEP(self): def decoder_SSMVEP(self):
'''模型训练''' '''模型训练'''
if self.load_model == False and all( if self.load_model == False and all(
self.trainLabel.count(i) >= self.single_train for i in range(len(self.list_freqs))): # 模型尚未训练完成 self.trainLabel.count(i) >= self.single_train for i in [1, 2]): # 模型尚未训练完成
self.trainData = np.array(self.trainData) self.trainData = np.array(self.trainData)
self.trainLabel = np.array(self.trainLabel) self.trainLabel = np.array(self.trainLabel)
print(np.shape(self.trainData), (self.trainLabel)) algo_log(f"开始SSMVEP模型训练数据形状{np.shape(self.trainData)},标签形状:{self.trainLabel.shape}", level="DEBUG")
# 保存多个数组到文件 if save_train_data == 1:
# np.savez('20250520_yy.npz', array1=self.trainData, array2=self.trainLabel) now_str = datetime.now().strftime("%Y%m%d_%H%M%S")
# self.decoder = self.fbtdca.fit(self.trainData, self.trainLabel, Yf=self.Yf) save_path = f"{now_str}.npz"
np.savez(save_path, array1=self.trainData, array2=self.trainLabel)
self.decoder = self.tdca.fit(self.trainData, self.trainLabel, Yf=self.Yf) self.decoder = self.tdca.fit(self.trainData, self.trainLabel, Yf=self.Yf)
now = datetime.now() now = datetime.now()
formatted_time = now.strftime('%H:%M:%S.%f')[:-3] formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
print('模型训练完成', formatted_time) algo_log(f"SSMVEP模型训练完成时间{formatted_time}", level="DEBUG")
self.load_model = True self.load_model = True
self.zmqServer.broadcast_message('paradigm', 1) self.zmqServer.broadcast_message('paradigm', 1)
'''训练阶段采集数据''' '''训练阶段采集数据'''
if self.zmqServer.state_mode == 'train': # 训练状态 if self.zmqServer.state_mode == 'train': # 训练状态
if self.zmqServer.StartTrain: if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \
self.train_epoch[1] + self.zmqServer.event_inner_idx:
self.currentLabel = self.zmqServer.currentLabel self.currentLabel = self.zmqServer.currentLabel
self.zmqServer.StartTrain = False trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \ algo_log(f"取出的:{trainTrial.shape}event{trainTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
self.train_epoch[1] \ trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
+ self.zmqServer.event_inner_idx: trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[
0]:self.zmqServer.event_inner_idx + self.train_epoch[1]]
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
self.trainLabel, list) \
and self.trainLabel.count(self.currentLabel) < self.single_train:
self.trainData.append(trainTrial)
self.trainLabel.append(self.currentLabel)
algo_log(f"SSMVEP训练集{np.shape(self.trainData)}", level="DEBUG")
else:
time.sleep(0.0001) time.sleep(0.0001)
return return
print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.zmqServer.event_inner_idx])
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[
0]:self.zmqServer.event_inner_idx + self.train_epoch[1]]
print('trial: ', self.zmqServer.event_inner_idx, self.train_epoch[0], self.train_epoch[1])
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
self.trainLabel, list) \
and self.trainLabel.count(self.currentLabel) < self.single_train:
self.trainData.append(trainTrial)
self.trainLabel.append(self.currentLabel)
elif self.zmqServer.state_mode == 'predict': # 测试状态 elif self.zmqServer.state_mode == 'predict': # 测试状态
if self.load_model == False: # 模型尚未训练完成 if self.load_model == False: # 模型尚未训练完成
@@ -298,15 +303,15 @@ class Decoder_main(threading.Thread):
self.zmqServer.StartDecode = False self.zmqServer.StartDecode = False
now = datetime.now() now = datetime.now()
formatted_time = now.strftime('%H:%M:%S.%f')[:-3] formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
print('启动预测 ', formatted_time) algo_log(f"SSMVEP模型启动预测 {formatted_time}", level="DEBUG")
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \ if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
self.interval_epoch[1] \ self.interval_epoch[1] \
+ self.zmqServer.event_inner_idx: + self.zmqServer.event_inner_idx:
# algo_log(f"SSMVEP模型启动预测 {self.zmqServer.epoch_finished}", level="DEBUG")
time.sleep(0.0001) time.sleep(0.0001)
return return
data = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 读取全部数据 data = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 读取全部数据
print('取出的: ', data.shape, 'event: ', data[-2, self.zmqServer.event_inner_idx]) algo_log(f"取出的:{data.shape}, event: {data[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
data = self.preprocess(data[:self.n_chan, :]) # 预处理 data = self.preprocess(data[:self.n_chan, :]) # 预处理
data = data[:, data = data[:,
self.zmqServer.event_inner_idx + self.interval_epoch[ self.zmqServer.event_inner_idx + self.interval_epoch[
@@ -317,26 +322,28 @@ class Decoder_main(threading.Thread):
choosenNum, features_2 = self.decoder.predict(pad_eeg_test) choosenNum, features_2 = self.decoder.predict(pad_eeg_test)
if isinstance(choosenNum, np.ndarray): if isinstance(choosenNum, np.ndarray):
choosenNum = choosenNum[0] choosenNum = choosenNum[0]
print('结果:', choosenNum, 'rho: ', sorted(features_2[0]), algo_log(f"结果:{choosenNum}, rho: {sorted(features_2[0])[-1] - sorted(features_2[0])[-2]}", level="DEBUG")
sorted(features_2[0])[-1] - sorted(features_2[0])[-2])
self.zmqServer.broadcast_message('result', int(choosenNum)) self.zmqServer.broadcast_message('result', int(choosenNum))
print('发送给界面完成。') algo_log("SSMVEP发送给界面完成。", level="DEBUG")
else: # 休息状态 else: # 休息状态
if self.zmqServer.open_Impedance == False: # 非阻抗检测状态 if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25: time.sleep(0.005)
time.sleep(0.005) return
return self.zmqServer.paradigmBuffer.getData(25)
self.zmqServer.paradigmBuffer.getData(25)
def decoder_MI(self): def decoder_MI(self):
'''模型训练''' '''模型训练'''
if self.train_started == False and all( if self.train_started == False and all(
self.trainLabel.count(i) >= self.single_train for i in range(self.num_target)): # 模型尚未训练 self.trainLabel.count(i) >= self.single_train for i in [1, 2]): # 模型尚未训练
self.zmqServer.broadcast_message('paradigm', 2) # 模型训练前,训练集采集完毕,通知上位机 self.zmqServer.broadcast_message('paradigm', 2) # 模型训练前,训练集采集完毕,通知上位机
self.train_started = True self.train_started = True
self.trainData = np.array(self.trainData) self.trainData = np.array(self.trainData)
self.trainLabel = np.array(self.trainLabel) + 1 self.trainLabel = np.array(self.trainLabel)
# print('训练集:',np.shape(self.trainData), (self.trainLabel)) algo_log(f"MI开始训练训练集{np.shape(self.trainData)}标签shape{np.shape(self.trainLabel)}", level="DEBUG")
if save_train_data == 1:
now_str = datetime.now().strftime("%Y%m%d_%H%M%S")
save_path = f"{now_str}.npz"
np.savez(save_path, array1=self.trainData, array2=self.trainLabel)
p = mp.Process(target=onlineTrain, args=(self.mp_data_queue, self.mp_result_queue)) # 开启子进程,训练模型 p = mp.Process(target=onlineTrain, args=(self.mp_data_queue, self.mp_result_queue)) # 开启子进程,训练模型
p.start() p.start()
self.mp_data_queue.put({'data': self.trainData, 'label': self.trainLabel, 'modelPath': self.modelPath, self.mp_data_queue.put({'data': self.trainData, 'label': self.trainLabel, 'modelPath': self.modelPath,
@@ -347,7 +354,7 @@ class Decoder_main(threading.Thread):
try: try:
result = self.mp_result_queue.get_nowait() result = self.mp_result_queue.get_nowait()
if result['status'] == 'success': if result['status'] == 'success':
print("模型训练完成,加载新模型") algo_log("MI模型训练完成,加载新模型", level="DEBUG")
# 调用模型 # 调用模型
self.model = torch.load(self.modelPath, weights_only=False) self.model = torch.load(self.modelPath, weights_only=False)
self.model.eval() self.model.eval()
@@ -360,45 +367,43 @@ class Decoder_main(threading.Thread):
self.load_model = True self.load_model = True
self.zmqServer.broadcast_message('paradigm', 1) # 模型调用完毕,通知上位机 self.zmqServer.broadcast_message('paradigm', 1) # 模型调用完毕,通知上位机
else: else:
print("训练失败:", result['msg']) algo_log("MI训练失败: " + result['msg'], level="DEBUG")
except Empty: except Empty:
pass # 还没完成 pass # 还没完成
except Exception as e: except Exception as e:
print('模型调用失败: ', e) algo_log("MI模型训练失败: " + str(e), level="DEBUG")
'''训练阶段采集数据''' '''训练阶段采集数据'''
if self.zmqServer.state_mode == 'train' and self.train_started == False: # 训练状态 if self.zmqServer.state_mode == 'train' and self.train_started == False: # 训练状态
if self.zmqServer.StartTrain: if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \
self.currentLabel = self.zmqServer.currentLabel self.zmqServer.train_epoch[1] + self.zmqServer.event_inner_idx:
self.zmqServer.StartTrain = False self.currentLabel = self.zmqServer.currentLabel # 同步当前标签
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \ algo_log(f"训练队列数据:{self.zmqServer.paradigmBuffer.GetDataLenCount()}", level="DEBUG")
self.interval_epoch[1] \ originalTrial = self.zmqServer.paradigmBuffer.get_MIData() # 取出MI导联数据
+ self.zmqServer.event_inner_idx: algo_log(f"取出的:{originalTrial.shape},event: {originalTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.interval_epoch[
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
# algo_log(f"trial: {self.zmqServer.event_inner_idx},{self.interval_epoch[0]},{self.interval_epoch[1]}", level="DEBUG")
if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel,
list) \
and self.trainLabel.count(self.currentLabel) < self.single_train:
self.trainData.append(trainTrial)
self.trainLabel.append(self.currentLabel)
algo_log(f"训练集:{np.shape(self.trainData)}", level="DEBUG")
self.plotData.append(originalTrial[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]])
self.plotLabel.append(self.currentLabel)
else:
time.sleep(0.0001) time.sleep(0.0001)
return return
print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
originalTrial = self.zmqServer.paradigmBuffer.get_MIData() # 取出MI导联数据
print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, self.zmqServer.event_inner_idx])
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.interval_epoch[
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
print('trial: ', self.zmqServer.event_inner_idx, self.interval_epoch[0], self.interval_epoch[1])
if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel,
list) \
and self.trainLabel.count(self.currentLabel) < self.single_train:
self.trainData.append(trainTrial)
self.trainLabel.append(self.currentLabel)
print('训练集:', np.shape(self.trainData))
self.plotData.append(originalTrial[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]])
self.plotLabel.append(self.currentLabel)
elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态 elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态
if self.zmqServer.StartDecode: if self.zmqServer.StartDecode:
self.zmqServer.StartDecode = False self.zmqServer.StartDecode = False
now = datetime.now() now = datetime.now()
formatted_time = now.strftime('%H:%M:%S.%f')[:-3] formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
print('启动预测 ', formatted_time) algo_log(f"MI启动预测 {formatted_time}", level="DEBUG")
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \ if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
self.interval_epoch[1] \ self.interval_epoch[1] \
@@ -406,7 +411,7 @@ class Decoder_main(threading.Thread):
time.sleep(0.0001) time.sleep(0.0001)
return return
originalData = self.zmqServer.paradigmBuffer.get_MIData() # 读取全部数据 originalData = self.zmqServer.paradigmBuffer.get_MIData() # 读取全部数据
print('取出的: ', originalData.shape, 'event: ', originalData[-2, self.zmqServer.event_inner_idx]) algo_log(f"取出的:{originalData.shape},event: {originalData[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
start = time.time() start = time.time()
data = self.preprocess(originalData[:self.n_chan, :]) # 预处理 data = self.preprocess(originalData[:self.n_chan, :]) # 预处理
data = data[:, data = data[:,
@@ -423,16 +428,15 @@ class Decoder_main(threading.Thread):
Cls = self.model(test_data) Cls = self.model(test_data)
y_pred = torch.max(Cls, 1)[1] y_pred = torch.max(Cls, 1)[1]
self.plotLabel.append(int(y_pred.item())) self.plotLabel.append(int(y_pred.item()))
print('运动意图识别: ', y_pred) algo_log(f"MI运动意图识别: {y_pred}")
self.zmqServer.broadcast_message('paradigm', int(y_pred.item())) self.zmqServer.broadcast_message('result', int(y_pred.item()))
end = time.time() end = time.time()
print(f'发送给界面完成,耗时{end - start:.3f}s。') algo_log(f'MI发送给界面完成,耗时{end - start:.3f}s。')
else: # 休息状态 else: # 休息状态
if self.zmqServer.open_Impedance == False: # 非阻抗检测状态 if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25: time.sleep(0.005)
time.sleep(0.005) return
return self.zmqServer.paradigmBuffer.getData(25)
self.zmqServer.paradigmBuffer.getData(25)
# def decoder_concentration(self): # def decoder_concentration(self):
# if self.zmqServer.state_mode == 'predict': # if self.zmqServer.state_mode == 'predict':

View File

@@ -34,7 +34,7 @@ cudnn.benchmark = True
cudnn.deterministic = True cudnn.deterministic = True
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
# writer = SummaryWriter('./TensorBoardX/') # writer = SummaryWriter('./TensorBoardX/')
from logs.log import algo_log
# Convolution module # Convolution module
# use conv to capture local features, instead of postion embedding. # use conv to capture local features, instead of postion embedding.
@@ -318,11 +318,7 @@ class ExP():
train_pred = torch.max(outputs, 1)[1] train_pred = torch.max(outputs, 1)[1]
train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0)) train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
print('Epoch:', e, algo_log(f"Epoch = {e}, Train loss = {loss.detach().cpu().numpy():.6f}, Test loss = {loss_test.detach().cpu().numpy():.6f}, Train accuracy = {train_acc:.6f}, Test accuracy = {acc:.6f}", level="debug")
' Train loss: %.6f' % loss.detach().cpu().numpy(),
' Test loss: %.6f' % loss_test.detach().cpu().numpy(),
' Train accuracy %.6f' % train_acc,
' Test accuracy is %.6f' % acc)
self.log_write.write(str(e) + " " + str(acc) + "\n") self.log_write.write(str(e) + " " + str(acc) + "\n")
num = num + 1 num = num + 1
@@ -335,8 +331,8 @@ class ExP():
torch.save(self.model, model_path) torch.save(self.model, model_path)
averAcc = averAcc / num averAcc = averAcc / num
print('The average accuracy is:', averAcc) algo_log(f"The average accuracy is: {averAcc}", level="debug")
print('The best accuracy is:', bestAcc) algo_log(f"The best accuracy is: {bestAcc}", level="debug")
self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n") self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n") self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
@@ -346,10 +342,10 @@ class ExP():
def onlineTrain(data_queue,result_queue): def onlineTrain(data_queue,result_queue):
import torch import torch
print(f"[DEBUG] torch.__version__ = {torch.__version__}") algo_log(f"[DEBUG] torch.__version__ = {torch.__version__}", level="debug")
print(f"[DEBUG] torch.cuda.is_available() = {torch.cuda.is_available()}") algo_log(f"[DEBUG] torch.cuda.is_available() = {torch.cuda.is_available()}", level="debug")
if torch.cuda.is_available(): if torch.cuda.is_available():
print(f"[DEBUG] GPU = {torch.cuda.get_device_name(0)}") algo_log(f"[DEBUG] GPU = {torch.cuda.get_device_name(0)}", level="debug")
try: try:
starttime = datetime.datetime.now() starttime = datetime.datetime.now()
@@ -366,12 +362,13 @@ def onlineTrain(data_queue,result_queue):
data = data_queue.get(timeout=30) data = data_queue.get(timeout=30)
all_data, all_label,model_path,n_chan = data['data'], data['label'],data['modelPath'],data['n_chan'] all_data, all_label,model_path,n_chan = data['data'], data['label'],data['modelPath'],data['n_chan']
exp = ExP(n_chan) exp = ExP(n_chan)
print('训练参数: ',np.shape(all_data),np.shape(all_label),model_path) algo_log(f"训练参数: {np.shape(all_data)}, {np.shape(all_label)}, {model_path}", level="debug")
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,model_path) bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,model_path)
print('THE BEST ACCURACY IS ' + str(bestAcc)) algo_log(f"THE BEST ACCURACY IS {str(bestAcc)}", level="debug")
endtime = datetime.datetime.now() endtime = datetime.datetime.now()
print('train duration: ',str(endtime - starttime)) algo_log(f"train duration: {endtime - starttime}", level="debug")
# 将模型或参数传回 # 将模型或参数传回
result_queue.put({ result_queue.put({
@@ -387,7 +384,7 @@ def offlineTrain(all_data,all_label,modelPath):
# seed_n = np.random.randint(2025) # seed_n = np.random.randint(2025)
seed_n = 1877 seed_n = 1877
print('seed is ' + str(seed_n)) algo_log(f"seed is {seed_n}", level="debug")
random.seed(seed_n) random.seed(seed_n)
np.random.seed(seed_n) np.random.seed(seed_n)
torch.manual_seed(seed_n) torch.manual_seed(seed_n)
@@ -397,13 +394,12 @@ def offlineTrain(all_data,all_label,modelPath):
exp = ExP() exp = ExP()
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,modelPath) bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,modelPath)
print('THE BEST ACCURACY IS ' + str(bestAcc)) algo_log('THE BEST ACCURACY IS ' + str(bestAcc), level="debug")
endtime = datetime.datetime.now() endtime = datetime.datetime.now()
print('train duration: ',str(endtime - starttime)) algo_log(f"train duration: {endtime - starttime}", level="debug")
if __name__ == "__main__": if __name__ == "__main__":
print(time.asctime(time.localtime(time.time()))) algo_log(f"[DEBUG] time.asctime(time.localtime(time.time())) = {time.asctime(time.localtime(time.time()))}", level="debug")
print(time.asctime(time.localtime(time.time())))

View File

@@ -22,6 +22,7 @@ from einops import rearrange
from einops.layers.torch import Rearrange, Reduce from einops.layers.torch import Rearrange, Reduce
from torch.backends import cudnn from torch.backends import cudnn
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from logs.log import algo_log
# writer = SummaryWriter('./TensorBoardX/') # writer = SummaryWriter('./TensorBoardX/')
@@ -190,7 +191,7 @@ class ExP():
# 自动选择设备:有 GPU 用 GPU否则用 CPU # 自动选择设备:有 GPU 用 GPU否则用 CPU
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# self.device = torch.device("cpu") # self.device = torch.device("cpu")
print(f"Using device: {self.device}") algo_log(f"Using device: {self.device}", level="debug")
# 定义张量类型(不再强制使用 cuda # 定义张量类型(不再强制使用 cuda
self.Tensor = torch.FloatTensor self.Tensor = torch.FloatTensor

View File

@@ -13,9 +13,31 @@ Debug_64ch_Decoder_Optimize is an updated version that fixes several issues and
6. decoder class切换问题 6. decoder class切换问题
7. decoder_class切换时数据重置、各类参数重置 7. decoder_class切换时数据重置、各类参数重置
# realease log
- 2026年6月11日11:29:17 打包第一版包名runDecoder.dist_v0.0.0_beta_20260611.7z
- 2026年6月11日12:00:00 打包第二版包名runDecoder.dist_v0.0.0_beta_20260611.7z
- 修复上位机先发decoder_class, 后发open_impedence 带来decoder_main thread 阻塞问题
- 2026年6月12日15:05:47 runDecoder.dist_v0.0.2_beta_20260612
- 优化filter读数精度
# 常用命令 # 常用命令
source activate 3in1Py310 source activate 3in1Py310
python runDecoder.py python runDecoder.py
python datamock.py python datamock.py
python ZeroMQClient_mock.py python ZeroMQClient_mock.py
python filter_test.py
python upperHost_stimmock/MI_headless.py
# 打包命令
./nuitka_3in1_package.sh
# TODO
1. mvep是否要把list freq 开放到config
2. 滤波器参数 放到config文件
# debug log
## MI
Epoch采集完成|收到命令: {'method': 'train'|取出的
收到命令: {'method': 'train'|收到命令: {'method': 'train'|收到命令: {'method': 'predict'|事件检测到

View File

@@ -12,16 +12,17 @@ from scipy.io import loadmat
from scipy.linalg import qr from scipy.linalg import qr
from scipy.signal import filtfilt, lfilter from scipy.signal import filtfilt, lfilter
# from numpy.linalg import _umath_linalg # from numpy.linalg import _umath_linalg
from logs.log import algo_log
class FbccaDw: class FbccaDw:
def __init__(self, fs, num_target, num_chans, num_filter, num_harms, stimTime, parameter, width, winNum,method): def __init__(self, fs, num_target, num_chans, num_filter, num_harms, stimTime, parameter, width, winNum,method):
print('******************************************') algo_log('******************************************', level="debug")
print('parameter list') algo_log('parameter list',level="debug")
print('target:', num_target) algo_log(f"target: {num_target}", level="debug")
print('number of filter bank:', num_filter) algo_log(f"number of filter bank: {num_filter}", level="debug")
print('parameter:', parameter) algo_log(f"parameter: {parameter}", level="debug")
print('width:', width) algo_log(f"width: {width}", level="debug")
self.phase = 0 self.phase = 0
self.bandWidth = width self.bandWidth = width
self.winNum = winNum self.winNum = winNum
@@ -237,7 +238,7 @@ class FbccaDw:
dataFiltered, self.notchZh[0] = lfilter(self.north_b, self.north_a, data, zi=self.notchZh[0]) dataFiltered, self.notchZh[0] = lfilter(self.north_b, self.north_a, data, zi=self.notchZh[0])
return np.asmatrix(dataFiltered) return np.asmatrix(dataFiltered)
except Exception: except Exception:
print(Exception) algo_log(f"Exception: {Exception}", level="debug")
''' '''
getDataQ getDataQ

View File

@@ -20,7 +20,7 @@ class Beta_Calculate():
alpha_psd = np.sum(self.band_psd(freqs, psd, (8, 13))) alpha_psd = np.sum(self.band_psd(freqs, psd, (8, 13)))
theta_psd = np.sum(self.band_psd(freqs, psd, (4, 8))) theta_psd = np.sum(self.band_psd(freqs, psd, (4, 8)))
print(f"[功率] β={beta_psd:.2f} | α={alpha_psd:.2f} | θ={theta_psd:.2f}") # print(f"[功率] β={beta_psd:.2f} | α={alpha_psd:.2f} | θ={theta_psd:.2f}")
return beta_psd, alpha_psd, theta_psd return beta_psd, alpha_psd, theta_psd

View File

@@ -89,7 +89,8 @@ def zero_mq_client(server_address="tcp://127.0.0.1:8099"):
{"method": "train", "params": 1}, {"method": "train", "params": 1},
{"method": "rest", "params": 0}, {"method": "rest", "params": 0},
{"method": "predict", "params": 1}, {"method": "predict", "params": 1},
{"method": "getReport", "params": 0} {"method": "getReport", "params": 0},
{"method": "targetFreqs", "params": [11, 12, 13]}
] ]
# 打印消息集 # 打印消息集

View File

@@ -21,7 +21,7 @@ class ParadigmRingBuffer:
def appendBuffer(self, data): def appendBuffer(self, data):
if self.nUpdate == self.n_points: if self.nUpdate == self.n_points:
# raise Exception("Buffer is full") # raise Exception("Buffer is full")
algo_log("Buffer is full", record_once=True) algo_log("ParadigmRingBuffer is full", record_once=True)
n = data.shape[1] n = data.shape[1]

View File

@@ -5,95 +5,163 @@
import numpy as np import numpy as np
import time import time
import threading import threading
import queue
from scipy import signal from scipy import signal
from logs.log import algo_log from logs.log import algo_log
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from Tools.beta_calculate import Beta_Calculate
class FilterRingBuffer: class FilterRingBuffer:
def __init__(self, n_chan, n_points): def __init__(self, n_chan, n_points):
"""
初始化纯数据环形缓存(线程安全)
:param n_chan: 通道数
:param n_points: 总缓存点数与paradigmRingBuffer参数完全一致
"""
self.n_chan = n_chan self.n_chan = n_chan
self.n_points = n_points self.n_points = n_points
self.buffer = np.zeros((n_chan, n_points), dtype=np.float64) self.buffer = np.zeros((n_chan, n_points), dtype=np.float64)
self.current_ptr = 0 # 写入指针:指向下一个要写入的位置 self.current_ptr = 0
self.total_samples = 0 # 已写入总点数 self.total_samples = 0
self.lock = threading.Lock() # 线程安全锁 self.lock = threading.Lock() # 仅保护元数据
self.has_new_data = False
def appendBuffer(self, data): def appendBuffer(self, data):
""" n = data.shape[1]
追加数据到缓存与paradigmRingBuffer接口一致 if n == 0:
:param data: 输入数据shape=(n_chan, n_samples) return
"""
# 仅加锁读取/更新元数据
with self.lock: with self.lock:
n = data.shape[1] old_ptr = self.current_ptr
if n == 0: new_ptr = (old_ptr + n) % self.n_points
return new_total = min(self.total_samples + n, self.n_points)
self.has_new_data = True
# 环形写入逻辑:指针到末尾则绕回 # 数组写入(耗时操作,移出锁外)
write_end = self.current_ptr + n write_end = old_ptr + n
if write_end <= self.n_points: if write_end <= self.n_points:
self.buffer[:, self.current_ptr:write_end] = data self.buffer[:, old_ptr:write_end] = data
else: else:
split = self.n_points - self.current_ptr split = self.n_points - old_ptr
self.buffer[:, self.current_ptr:] = data[:, :split] self.buffer[:, old_ptr:] = data[:, :split]
self.buffer[:, :write_end - self.n_points] = data[:, split:] self.buffer[:, :write_end - self.n_points] = data[:, split:]
# 更新指针(取模保证环形)和计数(不超过缓存总长度) # 再次加锁更新最终元数据
self.current_ptr = write_end % self.n_points with self.lock:
self.total_samples = min(self.total_samples + n, self.n_points) self.current_ptr = new_ptr
self.total_samples = new_total
# ========== 新增:获取&清空新数据标记的方法 ==========
def check_and_clear_new_data(self):
"""检查是否有新数据,并一次性清空标记(消费后重置)"""
with self.lock:
flag = self.has_new_data
if flag:
self.has_new_data = False
return flag
def getData(self, count): def getData(self, count):
""" # 加锁获取最新元数据
从最新位置向前读取count个点环形读取
核心逻辑current_ptr是下一个写入位置 → 最新数据在current_ptr之前
:param count: 读取点数
:return: np.ndarray, shape=(n_chan, count)
"""
with self.lock: with self.lock:
count = min(count, self.total_samples) count = min(count, self.total_samples)
if count == 0: if count == 0:
return np.zeros((self.n_chan, 0)) return np.zeros((self.n_chan, 0))
# 环形读取end是当前写入指针最新数据的下一位start是end - count
end = self.current_ptr end = self.current_ptr
start = end - count start = end - count
if start >= 0:
return self.buffer[:, start:end].copy() # 数据读取、切片、拼接(无锁)
else: if start >= 0:
# 跨环形边界:前半部分从缓存末尾取,后半部分从开头取 res = self.buffer[:, start:end].copy()
part1 = self.buffer[:, start:] # start为负等价于n_points + start else:
part2 = self.buffer[:, :end] part1 = self.buffer[:, start:]
return np.concatenate((part1, part2), axis=1) part2 = self.buffer[:, :end]
res = np.concatenate((part1, part2), axis=1).copy()
return res
def get_latest_n_points(self, n): def get_latest_n_points(self, n):
"""
扩展方法获取最新的n个点不移动读指针用于滑动窗口
:param n: 点数
:return: np.ndarray, shape=(n_chan, n) | None数据不足时
"""
with self.lock: with self.lock:
if self.total_samples < n: if self.total_samples < n:
return None return None
return self.getData(n) return self.getData(n)
def GetDataLenCount(self): def GetDataLenCount(self):
"""获取当前缓存总点数(兼容原有接口)"""
with self.lock: with self.lock:
return self.total_samples return self.total_samples
def resetAllPara(self): def resetAllPara(self):
"""重置所有缓存和指针(兼容原有接口)"""
with self.lock: with self.lock:
self.buffer.fill(0.0) self.buffer.fill(0.0)
self.current_ptr = 0 self.current_ptr = 0
self.total_samples = 0 self.total_samples = 0
self.has_new_data = False # 重置时清空新数据标记
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# 2. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现 # 2. 独立 Beta PSD 计算线程(避免阻塞滤波主循环的 200ms 定时
# -----------------------------------------------------------------------------
class BetaPsdCalculator(threading.Thread):
"""独立的 Beta PSD 计算线程,使用队列与滤波主线程解耦"""
def __init__(self, fs=250, window_size=750):
super().__init__(daemon=True)
self.fs = fs
self.window_size = window_size
self._beta_calc = Beta_Calculate(Threshold_value_low=0, Threshold_value_high=0, fs=fs)
self._input_queue = queue.Queue(maxsize=2)
self._running = threading.Event()
self._running.set()
self._latest_beta = None
self._beta_lock = threading.Lock()
self.beta_broadcast_callback = None
def push_data(self, data):
"""供外部调用的线程安全数据推送接口"""
try:
self._input_queue.put_nowait(data)
except queue.Full:
try:
self._input_queue.get_nowait()
except queue.Empty:
pass
try:
self._input_queue.put_nowait(data)
except queue.Full:
pass
def get_latest_beta(self):
"""获取最新的 beta 值(线程安全)"""
with self._beta_lock:
return self._latest_beta
def run(self):
while self._running.is_set():
try:
data = self._input_queue.get(timeout=1.5)
if data is None:
break
try:
beta_psd, _, _ = self._beta_calc.calculate_all(
data, fs=self.fs, nperseg=min(self.window_size, data.shape[1])
)
with self._beta_lock:
self._latest_beta = round(float(beta_psd), 3)
if self.beta_broadcast_callback is not None:
self.beta_broadcast_callback(self._latest_beta)
except Exception as e:
algo_log(f"Beta PSD 计算异常: {e}", level='error')
except queue.Empty:
pass
def stop(self):
"""停止计算线程"""
self._running.clear()
try:
self._input_queue.put_nowait(None)
except queue.Full:
pass
if self.is_alive():
self.join(timeout=2)
# -----------------------------------------------------------------------------
# 3. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class SlidingFilter(threading.Thread): class SlidingFilter(threading.Thread):
def __init__( def __init__(
@@ -122,24 +190,42 @@ class SlidingFilter(threading.Thread):
# 滤波结果回调(外部可注册,获取滤波后的数据) # 滤波结果回调(外部可注册,获取滤波后的数据)
self.filter_result_callback = None self.filter_result_callback = None
# beta 每秒触发计数200ms步长5次 = 1s
self._beta_step_counter = 0
self._beta_steps_per_second = max(1, int(round(1.0 / step_sec))) # 5
self.slide_window = None # 滑动窗口缓存 (n_chan, window_size)
self.slide_ready = False # 窗口是否已填满初始数据
# 预计算滤波器系数(仅执行一次) # 预计算滤波器系数(仅执行一次)
self._init_filters() self._init_filters()
# 独立的 Beta 计算线程(避免阻塞滤波主循环)
self._beta_thread = BetaPsdCalculator(fs=srate, window_size=self.window_size)
def start(self):
"""同时启动 Beta 计算线程和滤波主线程"""
self._beta_thread.start()
super().start()
def set_beta_broadcast_callback(self, callback):
"""注册 Beta PSD 广播回调函数"""
self._beta_thread.beta_broadcast_callback = callback
def _init_filters(self): def _init_filters(self):
"""预计算所有滤波器系数(仅执行一次)""" """预计算所有滤波器系数(仅执行一次)"""
# 50Hz工频陷波Q=30工业标准 # 50Hz工频陷波Q=30工业标准
self.b_notch, self.a_notch = signal.iirnotch(50, 30, self.srate) self.b_notch, self.a_notch = signal.iirnotch(50, 30, self.srate)
# 8~30Hz带通FIR65阶线性相位 # 0.5~45Hz带通FIR65阶线性相位
self.b_bp = signal.firwin( self.b_bp = signal.firwin(
numtaps=65, numtaps=65,
cutoff=[8/(self.srate/2), 30/(self.srate/2)], cutoff=[0.5/(self.srate/2), 45/(self.srate/2)],
pass_zero=False, pass_zero=False,
window='hamming' window='hamming'
) )
self.a_bp = np.array([1.0]) self.a_bp = np.array([1.0])
def _filter_window_data(self, window_data): def _filter_window_data(self, window_data):
"""对3秒窗口数据执行滤波返回无边界效应的200ms数据""" """对3秒窗口数据执行滤波返回 (无边界效应的200ms数据, 完整3s滤波数据)"""
# 零相位滤波(无延迟,无边界效应) # 零相位滤波(无延迟,无边界效应)
filtered = window_data - np.mean(window_data, axis=-1, keepdims=True) filtered = window_data - np.mean(window_data, axis=-1, keepdims=True)
filtered = signal.filtfilt(self.b_notch, self.a_notch, filtered, axis=-1) filtered = signal.filtfilt(self.b_notch, self.a_notch, filtered, axis=-1)
@@ -150,40 +236,64 @@ class SlidingFilter(threading.Thread):
start_idx = self.window_size - 2 * self.step_size start_idx = self.window_size - 2 * self.step_size
end_idx = self.window_size - self.step_size end_idx = self.window_size - self.step_size
output_data = filtered[:, start_idx:end_idx].copy() output_data = filtered[:, start_idx:end_idx].copy()
return output_data return output_data, filtered
def run(self): def run(self):
"""线程主逻辑精确200ms触发一次滤波""" """线程主逻辑精确200ms触发一次滤波"""
# 精确定时核心基于perf_counter计算下一次执行时间补偿sleep误差 interval = self.step_sec # 0.2s
interval = self.step_sec # 200ms = 0.2秒 # 以启动时刻为绝对时间基准(核心改动)
next_run_time = time.perf_counter() base_time = time.perf_counter()
frame_count = 0 # 帧计数器,用于对齐时序
while self.running.is_set(): while self.running.is_set():
# 1. 等待到下一次执行时间(精确定时) # 计算理论执行时刻:严格按帧序号 × 步长
expect_time = base_time + frame_count * interval
current_time = time.perf_counter() current_time = time.perf_counter()
if current_time < next_run_time:
time.sleep(next_run_time - current_time) # 精确定时等待
next_run_time += interval # 补偿:下次执行时间基于上一次目标时间 if current_time < expect_time:
time.sleep(expect_time - current_time)
else: else:
# 若超时如滤波耗时超过200ms重置下一次时间避免累积误差 # 处理超时:仅告警,不重置基准(防止累积偏移
algo_log("滤波耗时超过200ms定时偏移", level='debug') algo_log(f"滤波任务超时,偏移 {(current_time - expect_time)*1000:.1f} ms", level='debug')
next_run_time = time.perf_counter() + interval
# 2. 执行滤波逻辑 frame_count += 1 # 帧序号自增,保证周期绝对稳定
if not self.ring_buffer.check_and_clear_new_data():
# 无新数据,不执行滤波、不发送数据
continue
# ========== 原有滤波逻辑 ==========
try: try:
# 获取最新的3秒窗口数据 if not self.slide_ready:
window_data = self.ring_buffer.get_latest_n_points(self.window_size) # 阶段1首次填满3s初始窗口
if window_data is None: full_data = self.ring_buffer.get_latest_n_points(self.window_size)
algo_log(f"缓存数据不足,当前缓存{self.ring_buffer.GetDataLenCount()}点,需{self.window_size}", level='debug') if full_data is None:
continue algo_log("初始窗口数据不足", level='debug')
continue
self.slide_window = full_data
self.slide_ready = True
else:
# 阶段2正常滑动 → 取最新50个新点增量拼接
new_step_data = self.ring_buffer.get_latest_n_points(self.step_size)
if new_step_data is None:
algo_log("滑动步长数据不足", level='debug')
continue
# 增量滑动丢弃前50点拼接新50点标准滑动窗口
self.slide_window = np.hstack([
self.slide_window[:, self.step_size:],
new_step_data
])
# 滤波并提取无边界效应的200ms数据 filtered_data, filtered_full = self._filter_window_data(self.slide_window[:64, :])
filtered_data = self._filter_window_data(window_data)
# Beta PSD 每秒计算一次
self._beta_step_counter += 1
if self._beta_step_counter >= self._beta_steps_per_second:
self._beta_step_counter = 0
self._beta_thread.push_data(filtered_full[:2, :])
# 回调返回结果(外部可处理)
if self.filter_result_callback is not None: if self.filter_result_callback is not None:
self.filter_result_callback(filtered_data[:64, :]) # 只发送前64通道数据 self.filter_result_callback(filtered_data)
except Exception as e: except Exception as e:
algo_log(f"滤波执行异常: {e}", level='error') algo_log(f"滤波执行异常: {e}", level='error')
@@ -192,17 +302,11 @@ class SlidingFilter(threading.Thread):
self.filter_result_callback = callback self.filter_result_callback = callback
def stop(self): def stop(self):
"""停止滤波线程(安全版)""" """停止滤波线程和 Beta 计算线程"""
# 1. 先设置停止标志Event.clear()是线程安全的) self._beta_thread.stop()
self.running.clear() self.running.clear()
# 2. 核心修复只有线程已启动且正在运行时才调用join
if self.is_alive(): if self.is_alive():
# 等待线程正常退出最多1秒
self.join(timeout=1) self.join(timeout=1)
# 超时未退出时打印警告,便于排查问题
if self.is_alive(): if self.is_alive():
algo_log("警告滤波线程在1秒内未正常退出可能存在阻塞操作", level="WARNING") algo_log("警告滤波线程在1秒内未正常退出可能存在阻塞操作", level="WARNING")
# 3. 无论线程是否启动,都打印停止日志
algo_log("滤波线程已停止") algo_log("滤波线程已停止")

View File

@@ -2,6 +2,7 @@
import ast import ast
import numpy as np import numpy as np
import threading import threading
import zmq
import json import json
import queue import queue
from typing import Dict from typing import Dict
@@ -13,14 +14,15 @@ from Zmq.filterProcess import FilterRingBuffer
from PubLibrary.InifileHelper import IniRead from PubLibrary.InifileHelper import IniRead
from logs.log import algo_log from logs.log import algo_log
import zmq zmqServer_host = str(IniRead('system', 'zmqServer_host', '127.0.0.1'))
class zmqServer(threading.Thread): class zmqServer(threading.Thread):
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None): def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None):
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.device_info = device_info self.device_info = device_info
self.host = host self.host = zmqServer_host
self.cmd_port = cmd_port # 命令交互端口收JSON命令 + 返JSON结果 self.cmd_port = cmd_port # 命令交互端口收JSON命令 + 返JSON结果
self.data_port = data_port # 数据交互端口:收二进制原始脑电 + 返二进制滤波结果 self.data_port = data_port # 数据交互端口:收二进制原始脑电 + 返二进制滤波结果
self.running = False self.running = False
@@ -92,6 +94,7 @@ class zmqServer(threading.Thread):
self.pack_contain_event = False self.pack_contain_event = False
self.event_inner_idx = -1 self.event_inner_idx = -1
self.interval_inited = False self.interval_inited = False
self.last_epoch_finish_time = None
def reset_state(self): def reset_state(self):
"""清空采集器状态和缓存数据""" """清空采集器状态和缓存数据"""
@@ -105,21 +108,21 @@ class zmqServer(threading.Thread):
def interval_init(self, decoder_class): def interval_init(self, decoder_class):
if decoder_class == 'ssmvep': if decoder_class == 'ssmvep':
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) # [0.2, 2.2]
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # [50, 550]
self.train_epoch = [ self.train_epoch = [
int(self.interval_epoch[0]), int(self.interval_epoch[0]),
int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate']) int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])
] ] # [50, 575]
self.latency = (self.interval_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 self.latency = (self.interval_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 #115包, 575个点
self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 #120包 600个点
elif decoder_class == 'mi': elif decoder_class == 'mi':
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) # [0.5, 4.5]
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] #[125, 1125]
self.train_epoch = self.interval_epoch.copy() self.train_epoch = self.interval_epoch.copy()
self.latency = self.interval_epoch[1] // 5 self.latency = self.interval_epoch[1] // 5 #225
self.train_latency = self.latency self.train_latency = self.latency #225
algo_log(f"时间窗初始化完成: {interval_epoch}", level="INFO") algo_log(f"时间窗初始化完成: {interval_epoch}", level="INFO")
self.count_events: Dict[str, int] = {} self.count_events: Dict[str, int] = {}
@@ -149,7 +152,8 @@ class zmqServer(threading.Thread):
msg = {'method': method, 'params': params} msg = {'method': method, 'params': params}
msg_bytes = json.dumps(msg).encode('utf-8') msg_bytes = json.dumps(msg).encode('utf-8')
algo_log(f"发送命令结果: {msg}", level="DEBUG") if msg['method'] != 'beta_psd':
algo_log(f"发送命令结果: {msg}", level="DEBUG")
# 广播到所有命令客户端 # 广播到所有命令客户端
for client_id in list(self.cmd_clients): for client_id in list(self.cmd_clients):
@@ -176,7 +180,7 @@ class zmqServer(threading.Thread):
# 转置为上位机需要的[50, 通道数]格式 # 转置为上位机需要的[50, 通道数]格式
filtered_data = filtered_data.T.astype(np.float64) filtered_data = filtered_data.T.astype(np.float64)
send_buf = filtered_data.tobytes() send_buf = filtered_data.tobytes()
algo_log(f"发送滤波数据,长度: {len(send_buf)}字节, filtered_data.shape: {filtered_data.shape}", level="DEBUG") # algo_log(f"发送滤波数据,长度: {len(send_buf)}字节, filtered_data.shape: {filtered_data.shape}", level="DEBUG", record_once=True)
self.data_send_queue.put(send_buf) self.data_send_queue.put(send_buf)
def _process_data_send_queue(self): def _process_data_send_queue(self):
@@ -193,7 +197,7 @@ class zmqServer(threading.Thread):
b"", b"",
send_buf send_buf
]) ])
algo_log(f"发送滤波数据成功,长度: {len(send_buf)}字节", level="DEBUG") algo_log(f"发送滤波数据成功,长度: {len(send_buf)}字节", level="DEBUG", record_once=True)
except Exception as e: except Exception as e:
algo_log(f"发送滤波数据失败: {e}", level="ERROR") algo_log(f"发送滤波数据失败: {e}", level="ERROR")
@@ -222,6 +226,9 @@ class zmqServer(threading.Thread):
algo_log(f"无效JSON命令: {message_bytes.hex()}", level="ERROR") algo_log(f"无效JSON命令: {message_bytes.hex()}", level="ERROR")
self.broadcast_message("error", {"code": 400, "message": "无效JSON格式"}) self.broadcast_message("error", {"code": 400, "message": "无效JSON格式"})
return return
except Exception as e:
algo_log(f"_handle_cmd_message exception: {e}", level="ERROR")
return
algo_log(f"收到命令: {message}", level="INFO") algo_log(f"收到命令: {message}", level="INFO")
method = message.get("method") method = message.get("method")
@@ -246,8 +253,20 @@ class zmqServer(threading.Thread):
self.decoder_switch = True self.decoder_switch = True
elif method == "train": elif method == "train":
self.state_mode = 'train' self.state_mode = 'train'
self.StartTrain = True resp = {
self.currentLabel = params "method": "train_response",
"params": {
"code": 200,
"message": "ok"
}
}
try:
resp_bytes = json.dumps(resp, ensure_ascii=False).encode("utf-8")
self.cmd_socket.send_multipart([ident, b"", resp_bytes])
algo_log(f"train 命令已即时回复客户端 {ident}", level="DEBUG")
except Exception as e:
algo_log(f"train 命令回复失败: {e}", level="ERROR")
return
elif method == "predict": elif method == "predict":
self.state_mode = 'predict' self.state_mode = 'predict'
if params == 1: #开始解码 if params == 1: #开始解码
@@ -255,6 +274,22 @@ class zmqServer(threading.Thread):
elif params == 2: #停止解码 elif params == 2: #停止解码
self.IsExitApp = True self.IsExitApp = True
self.running = False self.running = False
resp = {
"method": "predict_response",
"params": {
"code": 200,
"message": "ok"
}
}
try:
resp_bytes = json.dumps(resp, ensure_ascii=False).encode("utf-8")
self.cmd_socket.send_multipart([ident, b"", resp_bytes])
algo_log(f"predict 命令已即时回复客户端 {ident}", level="DEBUG")
except Exception as e:
algo_log(f"predict 命令回复失败: {e}", level="ERROR")
return
elif method == "rest": elif method == "rest":
self.state_mode = 'rest' self.state_mode = 'rest'
elif method == "impedance": elif method == "impedance":
@@ -268,7 +303,7 @@ class zmqServer(threading.Thread):
# -------------------------- 数据端口消息处理 -------------------------- # -------------------------- 数据端口消息处理 --------------------------
def _handle_data_message(self, frames): def _handle_data_message(self, frames):
"""处理8100端口二进制脑电数据消息""" """处理8100端口二进制脑电数据消息"""
algo_log(f"收到数据帧,总帧数:{len(frames)}", level="DEBUG", record_once=False) algo_log(f"收到数据帧,总帧数:{len(frames)}", level="DEBUG", record_once=True)
# 然后再进行解析 # 然后再进行解析
if len(frames) == 4: if len(frames) == 4:
# 你的上位机格式 # 你的上位机格式
@@ -276,6 +311,8 @@ class zmqServer(threading.Thread):
elif len(frames) == 3: elif len(frames) == 3:
# 标准格式 # 标准格式
ident, empty_sep, data_bytes = frames[:3] ident, empty_sep, data_bytes = frames[:3]
elif len(frames) == 2:
ident, data_bytes = frames[:2]
else: else:
return return
# 注册新的数据客户端(单客户端场景,自动覆盖旧身份) # 注册新的数据客户端(单客户端场景,自动覆盖旧身份)
@@ -286,7 +323,7 @@ class zmqServer(threading.Thread):
algo_log(f"新数据客户端连接成功: {ident}", level="INFO") algo_log(f"新数据客户端连接成功: {ident}", level="INFO")
try: try:
# 精确长度校验 # 精确长度校验
EXPECTED_BYTES = self.device_info['frame_points'] * self.device_info['channel_nums'] * 4 EXPECTED_BYTES = self.device_info['frame_points'] * self.device_info['channel_nums'] * np.dtype(np.float64).itemsize
if len(data_bytes) != EXPECTED_BYTES: if len(data_bytes) != EXPECTED_BYTES:
algo_log(f"数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节", level="ERROR") algo_log(f"数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节", level="ERROR")
return return
@@ -307,8 +344,22 @@ class zmqServer(threading.Thread):
if self.pack_contain_event: if self.pack_contain_event:
self.paradigmBuffer.resetAllPara() self.paradigmBuffer.resetAllPara()
self.paradigmBuffer.appendBuffer(data_np) self.paradigmBuffer.appendBuffer(data_np)
if self.epoch_finished: if self.epoch_finished:
algo_log('Epoch采集完成: ' + datetime.datetime.now().strftime('%H:%M:%S.%f')[:-3], level="DEBUG") now = datetime.datetime.now()
time_diff_str = ""
# 计算与上一次Epoch完成的时间差
if self.last_epoch_finish_time is not None:
# 时间差 单位保留3位小数
delta_seconds = (now - self.last_epoch_finish_time).total_seconds()
time_diff_str = f" | 与上一次间隔: {delta_seconds:.3f} s"
# 拼接日志,增加时间差信息
log_msg = f"Epoch采集完成: {now.strftime('%H:%M:%S.%f')[:-3]}{time_diff_str}"
algo_log(log_msg, level="DEBUG")
# 更新上一次Epoch完成时间为当前时间
self.last_epoch_finish_time = now
else: else:
self.paradigmBuffer.appendBuffer(data_np) self.paradigmBuffer.appendBuffer(data_np)
@@ -322,9 +373,9 @@ class zmqServer(threading.Thread):
def detect_event(self, samples): def detect_event(self, samples):
self.pack_contain_event = False self.pack_contain_event = False
# 第65通道为事件通道 # 第65通道为事件通道
events = np.array(samples[-2])[0].tolist() events = np.array(samples[-2], dtype=np.int32).tolist()
for idx, event in enumerate(events): for idx, event in enumerate(events):
if int(event) in self.events: if event in self.events:
new_key = "".join( new_key = "".join(
[ [
str(event), str(event),
@@ -332,11 +383,13 @@ class zmqServer(threading.Thread):
-%H-%M-%S"), -%H-%M-%S"),
] ]
) )
self.currentLabel = event
if event == self.predict_event: if event == self.predict_event:
self.count_events[new_key] = self.latency + 1 self.count_events[new_key] = self.latency + 1
else: else:
self.count_events[new_key] = self.train_latency + 1 self.count_events[new_key] = self.train_latency + 1
self.event_inner_idx = idx self.event_inner_idx = idx
algo_log(f"事件检测到: {events},索引: {idx}", level="DEBUG")
self.pack_contain_event = True self.pack_contain_event = True
# 倒计时并清理过期事件 # 倒计时并清理过期事件
@@ -356,7 +409,7 @@ class zmqServer(threading.Thread):
# -------------------------- 主循环 -------------------------- # -------------------------- 主循环 --------------------------
def run(self): def run(self):
self.running = True self.running = True
algo_log(f"ZMQ服务器启动成功 - 命令端口: {self.cmd_port}, 数据端口: {self.data_port}", level="INFO") algo_log(f"ZMQ服务器启动成功 - host: {self.host}, 命令端口: {self.cmd_port}, 数据端口: {self.data_port}", level="INFO")
try: try:
while self.running: while self.running:
@@ -372,13 +425,18 @@ class zmqServer(threading.Thread):
frames = self.cmd_socket.recv_multipart() frames = self.cmd_socket.recv_multipart()
self._handle_cmd_message(frames) self._handle_cmd_message(frames)
# 处理8100数据端口消息 # 处理8100数据端口消息(排空积压,消除标签延迟)
if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN: if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN:
frames = self.data_socket.recv_multipart() while True:
self._handle_data_message(frames) try:
frames = self.data_socket.recv_multipart(zmq.NOBLOCK)
self._handle_data_message(frames)
except zmq.Again:
break
except Exception as e: except Exception as e:
algo_log(f"服务器主循环异常: {e}", level="ERROR") algo_log(f"服务器主循环异常: {str(e)}", level="ERROR")
return
finally: finally:
self.running = False self.running = False
# 优雅关闭所有资源 # 优雅关闭所有资源

View File

@@ -15,9 +15,25 @@ Audio_device = 0
Rest_time = 2 Rest_time = 2
Upper_Host = 127.0.0.1 Upper_Host = 127.0.0.1
Upper_Port = 8088 Upper_Port = 8088
Decoder_Host = 127.0.0.1
Decoder_Port = 8099
Serial_port = COM44 Serial_port = COM44
algo_log_level = DEBUG save_train_data = 0
console_output = 1 zmqServer_host = 127.0.0.1
[algo_log]
# ========== 文件日志配置 ==========
file_log_enable = true
file_log_level = DEBUG
log_path = exe
retention_days = 3
# ========== 控制台/黑框配置 ==========
console_enable = true
console_show_window = true
console_log_level = DEBUG
; 64 导设备配置 ; 64 导设备配置
[device_type_1] [device_type_1]

View File

@@ -1,6 +1,7 @@
import zmq import zmq
import numpy as np import numpy as np
import time import time
import threading
from datetime import datetime from datetime import datetime
# ========== 参数配置 ========== # ========== 参数配置 ==========
@@ -11,6 +12,7 @@ EEG_FREQ = 10 # EEG 正弦波频率 Hz
EEG_AMP = 100.0 # EEG 幅值 100μV EEG_AMP = 100.0 # EEG 幅值 100μV
LABEL_INTERVAL = 5 # 标签间隔秒数 LABEL_INTERVAL = 5 # 标签间隔秒数
SERVER_ADDR = 'tcp://127.0.0.1:8100' SERVER_ADDR = 'tcp://127.0.0.1:8100'
LABEL_CMD_ADDR = 'tcp://127.0.0.1:8101' # 接收来自上位机范式的标签命令
# 发送间隔: 每包 5 采样点 / 250Hz = 20ms # 发送间隔: 每包 5 采样点 / 250Hz = 20ms
PKT_INTERVAL = N_SAMPLES_PER_PKT / FS PKT_INTERVAL = N_SAMPLES_PER_PKT / FS
@@ -65,6 +67,60 @@ def main():
sock.connect(SERVER_ADDR) sock.connect(SERVER_ADDR)
print(f"[{datetime.now().strftime('%H:%M:%S')}] ZMQ Dealer 连接到 {SERVER_ADDR}") print(f"[{datetime.now().strftime('%H:%M:%S')}] ZMQ Dealer 连接到 {SERVER_ADDR}")
# ========== 上位机标签命令监听 ==========
# 使用线程安全的队列接收来自 ssmvep_main.py 的标签命令
# 标签值: 1 (train 0), 2 (train 1), 99 (predict)
pending_label = [None] # [label_value or None]
label_lock = threading.Lock()
label_cmd_sock = ctx.socket(zmq.PULL)
label_cmd_sock.bind(LABEL_CMD_ADDR)
print(f"[{datetime.now().strftime('%H:%M:%S')}] 标签命令监听绑定到 {LABEL_CMD_ADDR}")
stop_recv = threading.Event()
def label_cmd_thread():
"""监听来自上位机范式的标签命令,写入 pending_label"""
while not stop_recv.is_set():
try:
msg = label_cmd_sock.recv_string(zmq.NOBLOCK)
label_val = int(msg)
with label_lock:
pending_label[0] = label_val
ts = datetime.now().strftime('%H:%M:%S')
label_name = {1: 'train_0', 2: 'train_1', 99: 'predict'}.get(label_val, str(label_val))
print(f"[{ts}] 收到标签命令: {label_name} -> label={label_val}")
except zmq.Again:
time.sleep(0.005)
except Exception as e:
print(f"[label_cmd_thread] 错误: {e}")
time.sleep(0.01)
label_thread = threading.Thread(target=label_cmd_thread, daemon=True)
label_thread.start()
print(f"[{datetime.now().strftime('%H:%M:%S')}] 标签命令监听线程已启动")
# 后台消费线程:持续 recv 从 ROUTER 返回的数据,避免 server 发送队列积压
recv_count = [0]
def consumer_thread():
"""消费线程:阻塞 recv丢弃收到的数据仅用于清空 ROUTER 发送队列"""
while not stop_recv.is_set():
try:
frames = sock.recv_multipart(zmq.NOBLOCK)
recv_count[0] += 1
# 收到的格式: [identity, '', filtered_data_bytes]
if recv_count[0] % 500 == 0:
print(f"[{datetime.now().strftime('%H:%M:%S')}] 消费线程已丢弃 {recv_count[0]} 帧滤波数据")
except zmq.Again:
time.sleep(0.01)
except zmq.error.Again: # 兼容旧版
time.sleep(0.01)
consumer = threading.Thread(target=consumer_thread, daemon=True)
consumer.start()
print(f"[{datetime.now().strftime('%H:%M:%S')}] 消费线程已启动daemon")
global_sample_idx = 0 # 全局采样点计数器 global_sample_idx = 0 # 全局采样点计数器
label_type = 1 # 当前标签类型: 1 或 2 label_type = 1 # 当前标签类型: 1 或 2
label1_count = 0 # label=1 的序号计数器 label1_count = 0 # label=1 的序号计数器
@@ -74,7 +130,7 @@ def main():
print(f"[{datetime.now().strftime('%H:%M:%S')}] 开始发送模拟数据 ...") print(f"[{datetime.now().strftime('%H:%M:%S')}] 开始发送模拟数据 ...")
print(f" 采样率: {FS}Hz | 每包 {N_SAMPLES_PER_PKT} 采样点 | 发送间隔 {PKT_INTERVAL*1000:.0f}ms") print(f" 采样率: {FS}Hz | 每包 {N_SAMPLES_PER_PKT} 采样点 | 发送间隔 {PKT_INTERVAL*1000:.0f}ms")
print(f" EEG: {EEG_FREQ}Hz 正弦波 | 幅值 {EEG_AMP}μV") print(f" EEG: {EEG_FREQ}Hz 正弦波 | 幅值 {EEG_AMP}μV")
print(f" 标签: {LABEL_INTERVAL}s 末尾采样点触发 | label 1/2 交替") print(f" 标签: 来自上位机范式命令 (train_0=1, train_1=2, predict=99)")
print("-" * 50) print("-" * 50)
try: try:
@@ -84,31 +140,23 @@ def main():
# 构建当前包 # 构建当前包
packet = build_packet(global_sample_idx) packet = build_packet(global_sample_idx)
# 检查是否需要放置标签 # 检查是否有来自上位机范式的挂起标签命令
if should_send_label(global_sample_idx): with label_lock:
if label_type == 1: ext_label = pending_label[0]
label1_count += 1 if ext_label is not None:
label_value = 1 pending_label[0] = None
label_number = label1_count
else:
label2_count += 1
label_value = 2
label_number = label2_count
# 标签放在当前包最后一个采样点(索引 4
packet[4, 64] = label_value
packet[4, 65] = label_number
if ext_label is not None:
# 将标签写入当前包所有5个采样点的第65通道 (index 64)
# 覆盖全部采样点确保 event_inner_idx 无论落在哪个位置都能被正确检测
packet[:, 64] = float(ext_label)
ts = datetime.now().strftime('%H:%M:%S') ts = datetime.now().strftime('%H:%M:%S')
print(f"[{ts}] 标签触发: label={label_value}, 序号={label_number} " print(f"[{ts}] 标签: label={ext_label} -> ch64[all 5 samples] (global_sample_idx={global_sample_idx})")
f"(global_sample_idx={global_sample_idx})")
# 交替标签类型 # 发送: multipart 2帧 ['', data]
label_type = 2 if label_type == 1 else 1 # 使用标准格式ROUTER 会自动附加 ZMQ 分配的客户端身份
# 发送: multipart 3帧 [identity, '', data]
# 使用标准格式3帧ROUTER 会自动附加 ZMQ 分配的客户端身份
sock.send_multipart([ sock.send_multipart([
b'',
packet.tobytes() packet.tobytes()
]) ])
@@ -129,6 +177,9 @@ def main():
except KeyboardInterrupt: except KeyboardInterrupt:
print(f"\n[{datetime.now().strftime('%H:%M:%S')}] 停止发送,共发送 {packet_count}") print(f"\n[{datetime.now().strftime('%H:%M:%S')}] 停止发送,共发送 {packet_count}")
finally: finally:
stop_recv.set()
consumer.join(timeout=2)
label_cmd_sock.close()
sock.close() sock.close()
ctx.term() ctx.term()

421
filter_test.py Normal file
View File

@@ -0,0 +1,421 @@
# -*- coding: utf-8 -*-
"""
脑电滤波服务 8100端口测试工具【统计逻辑专项优化版】
优化点:
1. 5秒预热(250个发包),预热结束后才启动丢包/数据统计
2. 业务比例0.02s发1包200ms收1包 → 每 10 个发包对应 1 个回包
3. 通道校验:发送(5,66) 仅对比前64通道接收(50,64)全通道比对
4. 区分:全局总包数 / 有效统计区间包数、理论收包数、实际收包数、丢包数、丢包率
5. 新增64通道整体数据均值/极值比对,校验数据有效性
通信规范send_multipart([client_id, b"", data_buf]) 三帧报文,服务端 recv_multipart 长度=3
"""
import sys
import time
import threading
import logging
import traceback
from collections import deque
import numpy as np
import zmq
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
# ===================== 全局前置修复Matplotlib中文字体 & 负号显示 =====================
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "WenQuanYi Micro Hei"]
plt.rcParams["axes.unicode_minus"] = False
# ===================== 【1. 全局业务固定参数(核心统计规则)】 =====================
# ZMQ 服务端配置
ZMQ_SERVER_IP = "127.0.0.1"
ZMQ_SERVER_PORT = 8100
ZMQ_SOCKET_TIMEOUT = 3000 # 套接字超时(ms)
POLL_TIMEOUT = 10 # Poll轮询超时(ms)
# 时序 & 统计核心规则(严格对齐现场业务)
SEND_INTERVAL = 0.02 # 上位机发包间隔20ms/包
RECV_INTERVAL = 0.2 # 服务端回包间隔200ms/包
PREHEAT_SECONDS = 5.0 # 滤波缓存预热时长5秒
# 计算:预热需要的发包总数 = 预热时长 / 单包发送间隔
PREHEAT_SEND_PACKS = int(PREHEAT_SECONDS / SEND_INTERVAL) # 5 / 0.02 = 250 包
# 收发包比例每多少个发包对应1个回包
PACK_RATIO = int(RECV_INTERVAL / SEND_INTERVAL) # 0.2 / 0.02 = 10
# 数据报文形状
PKG_SEND_SHAPE = (5, 66) # 发送包 (点数, 总通道)
PKG_RECV_SHAPE = (50, 64) # 回包 (点数, 有效脑电通道)
SAMPLE_RATE = 250
# 通道定义对比仅使用前64路脑电通道
CH_EEG_VALID = 64 # 共同对比通道数0~63
CH_EVENT = 64
CH_RESERVED = 65
# ZMQ 三帧报文固定字段
CLIENT_ID = b"test_client_001"
EMPTY_FRAME = b""
# 仿真信号配置
TARGET_CHANNEL = 0
SIGNAL_FREQ_LIST = [13]
SIGNAL_AMP = 1.8
NOISE_GAUSSIAN_AMP = 0.4
NOISE_POWER50_AMP = 0.3
EVENT_LABEL_VAL = 1
RESERVED_VAL = 0.0
# 可视化配置
MAX_PLOT_POINTS = 800
PLOT_REFRESH_INTERVAL = 80
FFT_N_POINTS = 256
PLOT_X_LIMIT_FREQ = (0, 60)
# 运行控制
MAX_RUN_SECONDS = None
ENABLE_RECONNECT = True
PRINT_STAT_INTERVAL = 5.0
# ===================== 【2. 全局变量 + 统计结构体(重构统计逻辑)】 =====================
g_running = threading.Event()
g_running.set()
data_lock = threading.Lock()
# 绘图缓冲区
raw_data_buf = deque(maxlen=MAX_PLOT_POINTS)
filt_data_buf = deque(maxlen=MAX_PLOT_POINTS)
# ===================== 全新统计变量(区分预热/正式统计) =====================
stat = {
# 全局总包数(包含预热包)
"total_send": 0,
"total_recv": 0,
# 有效统计区间预热250包之后
"valid_send": 0, # 有效发包数
"valid_recv": 0, # 有效收包数
"theo_recv": 0, # 理论应收到包数 = valid_send // PACK_RATIO
# 运行时间
"start_time": time.perf_counter(),
"last_print_time": time.perf_counter(),
# 数据校验缓存保存最新一包原始64通道数据用于和回包比对
"latest_raw_64ch": None
}
# ===================== 【3. 日志配置】 =====================
def init_logger():
log_format = "%(asctime)s | %(levelname)-8s | %(message)s"
logging.basicConfig(
level=logging.INFO,
format=log_format,
datefmt="%Y-%m-%d %H:%M:%S"
)
return logging.getLogger("FilterTest")
logger = init_logger()
# ===================== 【4. 仿真脑电数据生成 (5,66)】 =====================
def generate_eeg_packet(pkt_idx: int) -> np.ndarray:
"""生成单包 (5,66) 仿真数据"""
n_point, n_chan = PKG_SEND_SHAPE
base_t = pkt_idx * n_point / SAMPLE_RATE
t_arr = base_t + np.arange(n_point) / SAMPLE_RATE
data = np.zeros((n_point, n_chan), dtype=np.float64)
# 64路脑电信号
for ch in range(CH_EEG_VALID):
sig = 0.0
for freq in SIGNAL_FREQ_LIST:
sig += SIGNAL_AMP * np.sin(2 * np.pi * freq * t_arr)
# sig += NOISE_POWER50_AMP * np.sin(2 * np.pi * 50 * t_arr)
# sig += NOISE_GAUSSIAN_AMP * np.random.randn(n_point)
data[:, ch] = sig
# 事件通道、保留通道
data[:, CH_EVENT] = EVENT_LABEL_VAL
data[:, CH_RESERVED] = RESERVED_VAL
return data
# ===================== 【5. ZMQ 核心IO线程单连接+Poller保留原有通信逻辑】 =====================
def zmq_io_thread():
context = zmq.Context()
pkt_index = 0
send_interval = SEND_INTERVAL
logger.info(f"滤波预热配置:{PREHEAT_SECONDS}秒 / {PREHEAT_SEND_PACKS} 个发包后开始统计")
logger.info(f"收发比例:每 {PACK_RATIO} 个发包 → 1 个滤波回包")
while g_running.is_set():
try:
sock = context.socket(zmq.DEALER)
sock.setsockopt(zmq.RCVTIMEO, ZMQ_SOCKET_TIMEOUT)
sock.setsockopt(zmq.SNDTIMEO, ZMQ_SOCKET_TIMEOUT)
sock.connect(f"tcp://{ZMQ_SERVER_IP}:{ZMQ_SERVER_PORT}")
logger.info(f"ZMQ 连接成功 -> {ZMQ_SERVER_IP}:{ZMQ_SERVER_PORT}")
poller = zmq.Poller()
poller.register(sock, zmq.POLLIN)
next_send_ts = time.perf_counter()
while g_running.is_set():
# 全局运行时长限制
if MAX_RUN_SECONDS is not None:
run_sec = time.perf_counter() - stat["start_time"]
if run_sec > MAX_RUN_SECONDS:
logger.info(f"已到达设定运行时长 {MAX_RUN_SECONDS}s停止任务")
return
# ========== 1. 轮询接收服务端回包 ==========
socks_ready = dict(poller.poll(POLL_TIMEOUT))
if sock in socks_ready:
frames = sock.recv_multipart()
if not frames:
continue
recv_bytes = frames[-1]
if not recv_bytes:
continue
# 解析回包 (50,64)
filt_data = np.frombuffer(recv_bytes, dtype=np.float64)
expect_size = PKG_RECV_SHAPE[0] * PKG_RECV_SHAPE[1]
if filt_data.size != expect_size:
logger.warning(f"回包长度异常:实际{filt_data.size},预期{expect_size}")
continue
filt_data = filt_data.reshape(PKG_RECV_SHAPE)
# 全局收包计数
stat["total_recv"] += 1
# 仅预热完成后,计入有效统计收包
if stat["total_send"] > PREHEAT_SEND_PACKS:
stat["valid_recv"] += 1
# 写入绘图缓冲区
with data_lock:
filt_data_buf.extend(filt_data[:, TARGET_CHANNEL])
# ---------- 新增64通道数据比对发包前64通道 <-> 回包64通道 ----------
raw_64ch = stat["latest_raw_64ch"]
if raw_64ch is not None:
raw_mean = np.mean(raw_64ch)
filt_mean = np.mean(filt_data)
raw_amp = np.max(np.abs(raw_64ch))
filt_amp = np.max(np.abs(filt_data))
logger.debug(
f"【通道数据比对】原始64通道均值:{raw_mean:.4f} 幅值:{raw_amp:.4f} | "
f"滤波后均值:{filt_mean:.4f} 幅值:{filt_amp:.4f}"
)
# ========== 2. 精准定时发送数据包 ==========
current_ts = time.perf_counter()
if current_ts >= next_send_ts:
# 生成(5,66)仿真包
pkt_data = generate_eeg_packet(pkt_index)
pkt_index += 1
send_buf = pkt_data.tobytes()
# 标准三帧Multipart发送
sock.send_multipart([CLIENT_ID, EMPTY_FRAME, send_buf])
# ---------- 发包计数逻辑(核心优化:预热区分) ----------
stat["total_send"] += 1
# 预热完成后,计入有效发包
if stat["total_send"] > PREHEAT_SEND_PACKS:
stat["valid_send"] += 1
# 计算理论应收包数
stat["theo_recv"] = stat["valid_send"] // PACK_RATIO
# 缓存当前包前64通道用于后续数据比对
stat["latest_raw_64ch"] = pkt_data[:, :CH_EEG_VALID]
# 绘图缓冲区(单通道波形)
with data_lock:
raw_data_buf.extend(pkt_data[:, TARGET_CHANNEL])
# 更新下一次发包时间
next_send_ts += send_interval
# ========== 3. 定时打印统计信息(区分预热/正式统计) ==========
now = time.perf_counter()
if now - stat["last_print_time"] > PRINT_STAT_INTERVAL:
run_sec = now - stat["start_time"]
total_send = stat["total_send"]
total_recv = stat["total_recv"]
# 分支1仍在预热阶段
if total_send <= PREHEAT_SEND_PACKS:
remain = PREHEAT_SEND_PACKS - total_send
logger.info(
f"[预热中] 运行:{run_sec:.1f}s | 已发包:{total_send}/{PREHEAT_SEND_PACKS} | "
f"剩余预热包:{remain} | 暂不统计丢包"
)
# 分支2预热完成进入正式统计
else:
v_send = stat["valid_send"]
v_recv = stat["valid_recv"]
t_recv = stat["theo_recv"]
loss_cnt = t_recv - v_recv
loss_rate = (loss_cnt / t_recv * 100) if t_recv > 0 else 0.0
logger.info(
f"[正式统计] 运行:{run_sec:.1f}s | "
f"全局总包: 发{total_send}/收{total_recv} | "
f"有效区间: 发{v_send}/应收{t_recv}/实收{v_recv} | "
f"丢包数:{loss_cnt} | 丢包率:{loss_rate:.2f}%"
)
stat["last_print_time"] = now
except zmq.ZMQError as e:
if e.errno == zmq.EAGAIN:
continue
logger.warning(f"ZMQ 连接异常: {e}")
sock.close()
poller.unregister(sock)
if not ENABLE_RECONNECT:
break
logger.info("500ms 后尝试重连...")
time.sleep(0.5)
except Exception as e:
logger.error(f"IO线程未知异常:\n{traceback.format_exc()}")
break
context.term()
logger.info("ZMQ IO 线程已退出")
# ===================== 【6. 可视化绘图(无改动)】 =====================
def init_plot():
fig = plt.figure(figsize=(14, 9))
fig.suptitle(f"脑电滤波测试 | 观测通道: {TARGET_CHANNEL}", fontsize=14)
ax1 = plt.subplot(2, 2, 1)
ax1.set_title("原始输入波形 (含噪声+工频)")
ax1.set_ylabel("幅值")
ax1.grid(True, alpha=0.3)
line_raw, = ax1.plot([], [], color="#1f77b4", linewidth=1)
ax2 = plt.subplot(2, 2, 2)
ax2.set_title("滤波后输出波形")
ax2.set_ylabel("幅值")
ax2.grid(True, alpha=0.3)
line_filt, = ax2.plot([], [], color="#d62728", linewidth=1)
ax3 = plt.subplot(2, 2, 3)
ax3.set_title("原始信号频谱")
ax3.set_xlabel("频率 (Hz)")
ax3.set_xlim(*PLOT_X_LIMIT_FREQ)
ax3.grid(True, alpha=0.3)
line_raw_fft, = ax3.plot([], [], color="#1f77b4")
ax4 = plt.subplot(2, 2, 4)
ax4.set_title("滤波后信号频谱")
ax4.set_xlabel("频率 (Hz)")
ax4.set_xlim(*PLOT_X_LIMIT_FREQ)
ax4.grid(True, alpha=0.3)
line_filt_fft, = ax4.plot([], [], color="#d62728")
plt.tight_layout(rect=[0, 0, 1, 0.96])
return fig, [line_raw, line_filt, line_raw_fft, line_filt_fft], [ax1, ax2, ax3, ax4]
def update_plot(frame, lines, axes):
line_raw, line_filt, line_raw_fft, line_filt_fft = lines
ax1, ax2, ax3, ax4 = axes
with data_lock:
raw_data = list(raw_data_buf)
filt_data = list(filt_data_buf)
if raw_data:
x_raw = np.arange(len(raw_data))
line_raw.set_data(x_raw, raw_data)
ax1.relim()
ax1.autoscale_view()
if filt_data:
x_filt = np.arange(len(filt_data))
line_filt.set_data(x_filt, filt_data)
ax2.relim()
ax2.autoscale_view()
def calc_fft(sig, n_fft):
if len(sig) < n_fft:
return [], []
win = np.hanning(n_fft)
sig_win = sig[-n_fft:] * win
fft_vals = np.fft.fft(sig_win)
fft_amp = np.abs(fft_vals)[:n_fft//2]
freq = np.fft.fftfreq(n_fft, 1/SAMPLE_RATE)[:n_fft//2]
return freq, fft_amp
freq_raw, amp_raw = calc_fft(raw_data, FFT_N_POINTS)
freq_filt, amp_filt = calc_fft(filt_data, FFT_N_POINTS)
line_raw_fft.set_data(freq_raw, amp_raw)
line_filt_fft.set_data(freq_filt, amp_filt)
ax3.relim()
ax3.autoscale_view(scaley=True)
ax4.relim()
ax4.autoscale_view(scaley=True)
return lines
# ===================== 【7. 资源释放 & 最终汇总统计】 =====================
def clean_resource():
g_running.clear()
logger.info("开始停止所有线程...")
time.sleep(0.3)
plt.close("all")
logger.info("资源释放完成")
def main():
logger.info("=" * 70)
logger.info("脑电滤波测试客户端【统计逻辑优化版】启动")
logger.info(f"服务端地址: {ZMQ_SERVER_IP}:{ZMQ_SERVER_PORT}")
logger.info(f"发包: {PKG_SEND_SHAPE}({SEND_INTERVAL*1000:.0f}ms) | 回包: {PKG_RECV_SHAPE}({RECV_INTERVAL*1000:.0f}ms)")
logger.info(f"预热规则: {PREHEAT_SECONDS}秒 / {PREHEAT_SEND_PACKS} 包后开启统计")
logger.info(f"收发比例: 每 {PACK_RATIO} 个发包对应 1 个回包")
logger.info("=" * 70)
# 启动ZMQ收发线程
io_thread = threading.Thread(target=zmq_io_thread, daemon=True, name="ZMQ_IO_Thread")
io_thread.start()
# 启动可视化
fig, lines, axes = init_plot()
ani = FuncAnimation(
fig, update_plot,
fargs=(lines, axes),
interval=PLOT_REFRESH_INTERVAL,
blit=True,
cache_frame_data=False
)
try:
plt.show()
except KeyboardInterrupt:
logger.info("收到 Ctrl+C 中断信号,准备退出")
finally:
# 输出最终完整汇总报表
run_total = time.perf_counter() - stat["start_time"]
total_send = stat["total_send"]
total_recv = stat["total_recv"]
v_send = stat["valid_send"]
v_recv = stat["valid_recv"]
t_recv = stat["theo_recv"]
loss_cnt = t_recv - v_recv
loss_rate = (loss_cnt / t_recv * 100) if t_recv > 0 else 0.0
logger.info(f"\n{'='*50} 最终运行汇总 {'='*50}")
logger.info(f"总运行时长: {run_total:.1f} s")
logger.info(f"【全局总包数】发送: {total_send} | 接收: {total_recv}")
logger.info(f"【有效统计区间(跳过预热{PREHEAT_SEND_PACKS}包)】")
logger.info(f" 有效发包: {v_send} | 理论应收包: {t_recv} | 实际收包: {v_recv}")
logger.info(f" 总丢包数: {loss_cnt} | 整体丢包率: {loss_rate:.2f} %")
logger.info(f"{'='*106}")
clean_resource()
sys.exit(0)
if __name__ == "__main__":
main()

View File

@@ -1,62 +1,156 @@
import os import os
from datetime import datetime import sys
from pathlib import Path
from datetime import datetime, timedelta
import logging import logging
from logging.handlers import RotatingFileHandler from logging.handlers import RotatingFileHandler
import inspect # 新增导入 import inspect
try:
import win32gui
import win32con
WIN32_AVAILABLE = True
except ImportError:
WIN32_AVAILABLE = False
from PubLibrary.InifileHelper import IniRead from PubLibrary.InifileHelper import IniRead
# ===================== 新增:获取 EXE 同级目录 =====================
def get_app_root():
"""获取 runDecoder.exe 所在的真实根目录(兼容 onefile / standalone"""
if getattr(sys, 'frozen', False):
# Nuitka / PyInstaller 打包后走这里
app_path = sys.executable
else:
# 本地源码运行时,取当前脚本目录
app_path = os.path.abspath(__file__)
return os.path.dirname(app_path)
console_output = IniRead('system', 'console_output', '1') # 程序根目录exe 同级)
log_level = IniRead('system', 'algo_log_level', 'INFO') APP_ROOT = Path(get_app_root())
# 日志文件夹名exe 同级下 logs 目录
DEFAULT_LOG_DIR = APP_ROOT / "logs"
# ===================== 读取 [algo_log] 配置 =====================
# 文件日志
FILE_LOG_ENABLE = IniRead("algo_log", "file_log_enable", "true").lower() == "true"
FILE_LOG_LEVEL = IniRead("algo_log", "file_log_level", "DEBUG").upper()
# 优先级:配置文件 > 默认exe同级logs
CFG_LOG_PATH = IniRead("algo_log", "log_path", "").strip()
if CFG_LOG_PATH == "exe":
LOG_DIR = DEFAULT_LOG_DIR
else:
LOG_DIR = Path(CFG_LOG_PATH)
LOG_RETENTION_DAYS = int(IniRead("algo_log", "retention_days", 3))
# 控制台日志 + 黑框控制
CONSOLE_ENABLE = IniRead("algo_log", "console_enable", "true").lower() == "true"
CONSOLE_SHOW_WINDOW = IniRead("algo_log", "console_show_window", "true").lower() == "true"
CONSOLE_LOG_LEVEL = IniRead("algo_log", "console_log_level", "INFO").upper()
# ===================== 全局常量与缓存 =====================
log_once_cache = set() log_once_cache = set()
# 缓存已经创建过的logger避免重复创建handler
logger_cache = {} logger_cache = {}
LOG_FILE_PREFIX = 'algo_log_'
# 确保日志目录存在
LOG_DIR.mkdir(parents=True, exist_ok=True)
LOG_DIR_STR = str(LOG_DIR) + "\\"
# 日志格式
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
# 日志级别映射
LEVEL_MAP = {
"DEBUG": logging.DEBUG,
"INFO": logging.INFO,
"WARNING": logging.WARNING,
"ERROR": logging.ERROR,
"FATAL": logging.FATAL
}
FILE_LOG_LEVEL_INT = LEVEL_MAP.get(FILE_LOG_LEVEL, logging.INFO)
CONSOLE_LOG_LEVEL_INT = LEVEL_MAP.get(CONSOLE_LOG_LEVEL, logging.INFO)
# ===================== Windows 控制台黑框显示/隐藏 =====================
def control_console_window():
if not sys.platform.startswith("win") or not WIN32_AVAILABLE:
return
try:
hwnd = win32gui.GetForegroundWindow()
if CONSOLE_SHOW_WINDOW:
win32gui.ShowWindow(hwnd, win32con.SW_SHOW)
else:
win32gui.ShowWindow(hwnd, win32con.SW_HIDE)
except Exception:
pass
control_console_window()
# ===================== 清理过期日志 =====================
def clean_old_logs():
try:
if not LOG_DIR.exists():
return
expire_date = datetime.now() - timedelta(days=LOG_RETENTION_DAYS)
for filename in os.listdir(LOG_DIR):
if not (filename.startswith(LOG_FILE_PREFIX) and filename.endswith('.log')):
continue
date_str = filename[len(LOG_FILE_PREFIX):-4]
try:
file_date = datetime.strptime(date_str, '%Y-%m-%d')
if file_date < expire_date:
file_path = LOG_DIR / filename
os.remove(file_path)
except ValueError:
continue
except Exception:
pass
# ===================== 初始化日志器 =====================
def init_module_logger(logger_name): def init_module_logger(logger_name):
log_dir = './logs/'
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f'algo_log_{datetime.now().strftime("%Y-%m-%d")}.log')
# 已创建直接返回
if logger_name in logger_cache: if logger_name in logger_cache:
return logger_cache[logger_name] return logger_cache[logger_name]
clean_old_logs()
logger = logging.getLogger(logger_name) logger = logging.getLogger(logger_name)
logger.setLevel(log_level) logger.setLevel(logging.DEBUG)
if logger.handlers: if logger.handlers:
logger_cache[logger_name] = logger logger_cache[logger_name] = logger
return logger return logger
file_handler = RotatingFileHandler( formatter = logging.Formatter(LOG_FORMAT, datefmt=DATE_FORMAT)
log_file,
maxBytes=10*1024*1024,
backupCount=10,
encoding='utf-8'
)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
if console_output: # 文件日志
console_handler = logging.StreamHandler() if FILE_LOG_ENABLE:
current_date = datetime.now().strftime("%Y-%m-%d")
log_file = LOG_DIR / f"{LOG_FILE_PREFIX}{current_date}.log"
file_handler = RotatingFileHandler(
log_file,
maxBytes=10 * 1024 * 1024,
backupCount=10,
encoding='utf-8'
)
file_handler.setFormatter(formatter)
file_handler.setLevel(FILE_LOG_LEVEL_INT)
logger.addHandler(file_handler)
# 控制台日志
if CONSOLE_ENABLE:
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter) console_handler.setFormatter(formatter)
console_handler.setLevel(CONSOLE_LOG_LEVEL_INT)
logger.addHandler(console_handler) logger.addHandler(console_handler)
logger_cache[logger_name] = logger logger_cache[logger_name] = logger
return logger return logger
# ===================== 对外日志入口函数 =====================
def algo_log(content, level="INFO", record_once=False): def algo_log(content, level="INFO", record_once=False):
# 向上回溯1层栈拿到调用algo_log的代码文件信息 frame = inspect.currentframe()
frame = inspect.currentframe().f_back if frame:
file_path = frame.f_code.co_filename frame = frame.f_back.f_back
# 提取py文件名不带后缀/带后缀自选) file_name = os.path.basename(frame.f_code.co_filename) if frame else "unknown"
file_name = os.path.basename(file_path) # 例zmqServer.py
# file_name = os.path.splitext(os.path.basename(file_path))[0] # 例zmqServer
logger = init_module_logger(file_name) logger = init_module_logger(file_name)
@@ -67,13 +161,12 @@ def algo_log(content, level="INFO", record_once=False):
log_once_cache.add(log_key) log_once_cache.add(log_key)
level_upper = level.upper() level_upper = level.upper()
if level_upper == "DEBUG": log_func_map = {
logger.debug(content) "DEBUG": logger.debug,
elif level_upper == "WARNING": "INFO": logger.info,
logger.warning(content) "WARNING": logger.warning,
elif level_upper == "ERROR": "ERROR": logger.error,
logger.error(content) "FATAL": logger.fatal
elif level_upper == "FATAL": }
logger.fatal(content) log_func = log_func_map.get(level_upper, logger.info)
else: log_func(content)
logger.info(content)

54
nuitka_3in1_package.sh Normal file
View File

@@ -0,0 +1,54 @@
#!/bin/bash
# Git Bash 中文 UTF-8 兼容配置(通用版,无报错)
export LC_ALL=en_US.UTF-8
export LANG=en_US.UTF-8
echo "========================"
echo "Nuitka 打包脚本 - 优化稳定版"
echo "适配PyTorch2.0.0 + CUDA11.7 + 脑电解码项目"
echo "========================"
# ===================== 自定义配置区 =====================
PY_FILE="runDecoder.py" # 主程序文件
OUT_DIR="dist_nuitka" # 输出文件夹
MODEL_DIR="online_Models" # 模型文件夹
# ========================================================
# 检查主文件是否存在
if [ ! -f "${PY_FILE}" ]; then
echo "错误:未找到主文件 ${PY_FILE},请检查路径!"
read -n 1 -s -r -p "按任意键退出"
exit 1
fi
echo "开始打包:${PY_FILE}"
echo "输出目录:${OUT_DIR}"
# Nuitka 核心打包命令(无错误、无冗余、全依赖)
python -m nuitka \
--standalone \
--msvc=latest \
--module-parameter=torch-disable-jit=yes \
--enable-plugin=no-qt \
--include-package=numpy \
--include-module=numpy.core._multiarray_umath \
--include-package=scipy \
--no-deployment-flag=self-execution \
--include-data-dir="${MODEL_DIR}=${MODEL_DIR}" \
--output-dir="${OUT_DIR}" \
--remove-output \
"${PY_FILE}"
# 打包结果判断
if [ $? -eq 0 ]; then
echo -e "\n========================"
echo "✅ 打包成功!"
echo "📦 产物路径:${OUT_DIR}/${PY_FILE%.py}.exe"
echo "========================"
else
echo -e "\n❌ 打包失败!"
fi
# Git Bash 兼容的暂停
read -n 1 -s -r -p "按任意键退出..."
echo

View File

@@ -1,252 +0,0 @@
0 0.5
1 0.5
2 0.375
3 0.5
4 0.4375
5 0.375
6 0.5
7 0.5
8 0.375
9 0.375
10 0.375
11 0.375
12 0.5
13 0.5625
14 0.5625
15 0.5
16 0.5
17 0.5
18 0.5
19 0.5625
20 0.4375
21 0.5
22 0.5
23 0.375
24 0.375
25 0.375
26 0.375
27 0.375
28 0.3125
29 0.375
30 0.5625
31 0.5
32 0.5
33 0.5625
34 0.5625
35 0.3125
36 0.3125
37 0.3125
38 0.375
39 0.5625
40 0.3125
41 0.5625
42 0.3125
43 0.375
44 0.5625
45 0.5
46 0.375
47 0.375
48 0.3125
49 0.375
50 0.375
51 0.5
52 0.5625
53 0.375
54 0.5625
55 0.5625
56 0.375
57 0.375
58 0.375
59 0.5
60 0.3125
61 0.375
62 0.375
63 0.375
64 0.375
65 0.375
66 0.3125
67 0.375
68 0.5625
69 0.5625
70 0.5625
71 0.5
72 0.5625
73 0.375
74 0.375
75 0.375
76 0.375
77 0.375
78 0.5
79 0.375
80 0.375
81 0.5
82 0.375
83 0.375
84 0.375
85 0.375
86 0.3125
87 0.375
88 0.375
89 0.5
90 0.375
91 0.4375
92 0.3125
93 0.3125
94 0.375
95 0.375
96 0.375
97 0.375
98 0.3125
99 0.4375
100 0.375
101 0.375
102 0.375
103 0.3125
104 0.5625
105 0.5
106 0.5625
107 0.5625
108 0.5
109 0.3125
110 0.5625
111 0.5625
112 0.5
113 0.3125
114 0.5
115 0.3125
116 0.375
117 0.3125
118 0.3125
119 0.3125
120 0.3125
121 0.375
122 0.375
123 0.375
124 0.375
125 0.3125
126 0.375
127 0.375
128 0.375
129 0.375
130 0.5625
131 0.375
132 0.5
133 0.3125
134 0.3125
135 0.3125
136 0.375
137 0.5
138 0.3125
139 0.375
140 0.3125
141 0.3125
142 0.3125
143 0.5625
144 0.3125
145 0.375
146 0.5
147 0.5
148 0.375
149 0.4375
150 0.5
151 0.3125
152 0.375
153 0.375
154 0.375
155 0.3125
156 0.375
157 0.4375
158 0.4375
159 0.375
160 0.375
161 0.3125
162 0.375
163 0.375
164 0.375
165 0.3125
166 0.3125
167 0.3125
168 0.375
169 0.3125
170 0.3125
171 0.3125
172 0.375
173 0.3125
174 0.3125
175 0.5
176 0.3125
177 0.375
178 0.375
179 0.3125
180 0.3125
181 0.3125
182 0.3125
183 0.5625
184 0.5625
185 0.3125
186 0.5
187 0.5
188 0.5625
189 0.5
190 0.5625
191 0.5625
192 0.5625
193 0.5
194 0.5
195 0.5625
196 0.5625
197 0.5625
198 0.5625
199 0.5
200 0.5625
201 0.5625
202 0.375
203 0.375
204 0.375
205 0.375
206 0.375
207 0.5
208 0.5
209 0.5625
210 0.5625
211 0.5625
212 0.3125
213 0.5
214 0.5
215 0.5625
216 0.5
217 0.5
218 0.5
219 0.5625
220 0.5
221 0.4375
222 0.5
223 0.5
224 0.4375
225 0.5
226 0.4375
227 0.5
228 0.5
229 0.375
230 0.375
231 0.3125
232 0.375
233 0.375
234 0.375
235 0.5625
236 0.5625
237 0.5625
238 0.5625
239 0.5625
240 0.5
241 0.5
242 0.5
243 0.5625
244 0.5625
245 0.375
246 0.375
247 0.375
248 0.3125
249 0.375
The average accuracy is: 0.42675
The best accuracy is: 0.5625

52
requirements.txt Normal file
View File

@@ -0,0 +1,52 @@
Bottleneck==1.4.2
brotlicffi==1.2.0.0
certifi==2026.5.20
cffi==2.0.0
charset-normalizer==3.4.4
contourpy==1.3.2
cycler==0.12.1
einops==0.8.2
filelock==3.20.3
fonttools==4.63.0
gmpy2==2.2.2
idna==3.11
Jinja2==3.1.6
joblib==1.5.3
kiwisolver==1.5.0
MarkupSafe==3.0.2
matplotlib==3.10.9
mkl_fft==1.3.11
mkl_random==1.2.8
mkl-service==2.5.2
mpmath==1.3.0
networkx==3.4.2
Nuitka==4.1.1
numexpr==2.14.1
numpy==1.24.3
packaging==26.0
pandas==2.3.3
pillow==12.2.0
pip==26.0.1
pycparser==3.0
pyparsing==3.3.2
pyserial==3.5
PySocks==1.7.1
python-dateutil==2.9.0.post0
pytz==2026.1.post1
pyzmq==27.1.0
requests==2.33.1
scikit-learn==1.7.1
scipy==1.15.3
setuptools==82.0.1
six==1.17.0
sympy==1.14.0
threadpoolctl==3.5.0
torch==2.0.0
torchaudio==2.0.0
torchsummary==1.5.1
torchvision==0.15.0
typing_extensions==4.15.0
tzdata==2026.2
urllib3==2.7.0
wheel==0.46.3
win_inet_pton==1.1.0

View File

@@ -1,7 +1,7 @@
import matplotlib # import matplotlib
matplotlib.use('Agg') # matplotlib.use('Agg')
import argparse # import argparse
import sys # import sys
import time import time
from Decoder import Decoder_main from Decoder import Decoder_main
from PubLibrary.RunOnce import is_program_running from PubLibrary.RunOnce import is_program_running

View File

@@ -0,0 +1,306 @@
"""
MI_headless.py
无界面版 MI 运动想象范式通讯流程模拟脚本。
复现 MI_main.py 的完整指令序列train 0/1, rest, predict, saveData
但不依赖 psychopy 也不打开任何窗口/音频,用 time.sleep 替代帧循环等待。
启动顺序:
1. runDecoder.py
2. datamock.py
3. MI_headless.py
"""
import sys
import os
import json
import time
import threading
import zmq
import numpy as np
import ast
from datetime import datetime
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from PubLibrary.InifileHelper import IniRead
personname = 'demo'
session = '01'
DATAMOCK_LABEL_ADDR = 'tcp://127.0.0.1:8101' # datamock 标签命令地址
# ========== ZMQ 结果接收服务 ==========
class ZmqResultServer(threading.Thread):
def __init__(self, port=8088):
threading.Thread.__init__(self)
self.port = port
self.running = True
self.energy = 0
self.paradigm = 0 # 0=个体校准, 1=康复训练, 2=等待模型训练
self.ChoosenNum = -1
self.context = zmq.Context()
self.socket = self.context.socket(zmq.ROUTER)
self.socket.bind(f"tcp://0.0.0.0:{self.port}")
self.daemon = True
self.trial_idx = 0
def run(self):
print(f"[Server] UpperHost_Server listening on {self.port}")
while self.running:
try:
frames = self.socket.recv_multipart(zmq.NOBLOCK)
if len(frames) < 3:
continue
message = json.loads(frames[2].decode('utf-8'))
method = message.get('method')
params = message.get('params')
if method == 'energy':
self.energy = params
elif method == 'paradigm':
self.paradigm = params
print(f"[Server] paradigm -> {params}")
elif method == 'result':
self.ChoosenNum = params
self.trial_idx += 1
print(f"[Server] result={self.ChoosenNum} (trial {self.trial_idx})")
except zmq.Again:
time.sleep(0.005)
except Exception as e:
print(f"[Server] error: {e}")
def stop(self):
self.running = False
self.socket.close()
self.context.term()
# ========== ZMQ 命令发送客户端 ==========
class ZmqCmdClient:
def __init__(self, host, port):
self.host = host
self.port = port
self.context = zmq.Context()
self.socket = self.context.socket(zmq.DEALER)
# PUSH socket 用于向 datamock.py 发送标签命令
self._label_sock = self.context.socket(zmq.PUSH)
self._label_sock.connect(DATAMOCK_LABEL_ADDR)
print(f"[Client] label PUSH connected to {DATAMOCK_LABEL_ADDR}")
def connect(self):
self.socket.connect(f"tcp://{self.host}:{self.port}")
print(f"[Client] connected to {self.host}:{self.port}")
def start_recv_thread(self, result_server):
"""启动后台线程,持续接收 decoder 通过 8099 ROUTER 回发的消息,并更新 result_server 的状态"""
self._result_server = result_server
self._stop_recv = threading.Event()
def _recv_loop():
while not self._stop_recv.is_set():
try:
frames = self.socket.recv_multipart(zmq.NOBLOCK)
# DEALER 收到的格式: [b'', json_bytes]
data_bytes = frames[-1]
message = json.loads(data_bytes.decode('utf-8'))
method = message.get('method')
params = message.get('params')
ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
print(f"[{ts}] [CmdClient] recv: {method}={params}")
if method == 'paradigm':
self._result_server.paradigm = params
print(f"[{ts}] [CmdClient] paradigm updated -> {params}")
elif method == 'result':
self._result_server.ChoosenNum = params
self._result_server.trial_idx += 1
print(f"[{ts}] [CmdClient] result={params} (trial {self._result_server.trial_idx})")
elif method == 'energy':
self._result_server.energy = params
except zmq.Again:
time.sleep(0.005)
except Exception as e:
print(f"[CmdClient recv] error: {e}")
time.sleep(0.01)
self._recv_thread = threading.Thread(target=_recv_loop, daemon=True)
self._recv_thread.start()
print(f"[Client] 后台接收线程已启动(监听 decoder 8099 回发消息)")
def stop_recv_thread(self):
if hasattr(self, '_stop_recv'):
self._stop_recv.set()
def _send_label(self, label_value):
"""向 datamock.py 发送标签命令"""
try:
self._label_sock.send_string(str(label_value), zmq.NOBLOCK)
except Exception as e:
print(f"[Client] label send error: {e}")
def send_data(self, method, params):
msg = {'method': method, 'params': params}
try:
self.socket.send_multipart([b'', json.dumps(msg).encode('utf-8')])
ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
print(f"[{ts}] send_data: {method}={params}")
# 根据 train/predict 命令向 datamock 发送标签
if method == 'train':
if params == 0:
self._send_label(1)
print(f"[Label] train 0 -> datamock label=1")
elif params == 1:
self._send_label(2)
print(f"[Label] train 1 -> datamock label=2")
elif method == 'predict':
self._send_label(99)
print(f"[Label] predict -> datamock label=99")
except Exception as e:
print(f"[Client] send error: {e}")
# ========== 主流程 ==========
def run_headless():
server = ZmqResultServer(port=8088)
server.start()
_dh = str(IniRead('system', 'Decoder_Host'))
_dp = int(IniRead('system', 'Decoder_Port'))
client = ZmqCmdClient(_dh, _dp)
client.connect()
client.start_recv_thread(server) # 启动后台接收线程,监听 decoder 8099 回发的 paradigm/result 消息
time.sleep(1) # 等待连接建立
client.send_data('decoderClass', 'mi')
time.sleep(4) # 等待 zmqServer 排空启动积压包datamock 提前连接会积压 ~3s 数据)
# MI_IntervalEpoch = [0.5, 4.5]trial时长 = 4.5-0.5 = 4.0s
_mi_iv = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) # [0.5, 4.5]
_trial_sec = float(_mi_iv[1] - _mi_iv[0]) # 4.0s
_margin = 1.0
train_time = max(5.0, _trial_sec + _margin) # 训练刺激时长(与 MI_main.py 保持一致)
# MI epoch latency = interval_epoch[1] // 5 = (4.5*250)//5 = 225包 × 20ms = 4.5s
# train_latency = 225包MI中 train_latency == latency
# 在 train_time 后需再等 epoch_wait 秒decoder 才能完成 epoch 采集
epoch_wait = _mi_iv[1] / _mi_iv[1] * (_mi_iv[1] * 250 // 5) * 0.02 # = latency * 20ms
# 更直接的计算latency = interval_epoch[1] // 5 = int(4.5*250)//5 = 225225*0.02 = 4.5s
epoch_wait = (int(_mi_iv[1] * 250) // 5) * 0.02 # 4.5s
# predict epoch wait与 train 相同MI中 latency == train_latency
predict_epoch_wait = epoch_wait # 4.5s
test_time = 7.0 # 预测窗口时长(与 MI_main.py 保持一致)
right_rehabilitation = float(IniRead('system', 'Right_rehabilitation'))
fault_rehabilitation = float(IniRead('system', 'Fault_rehabilitation'))
rest_time = float(IniRead('system', 'Rest_time'))
num_blocks = int(IniRead('system', 'Num_blocks'))
num_trials = int(IniRead('system', 'Num_trials'))
trained = 0
Num_Total = 0
Num_Success = 0
user_choice = []
print("=" * 50)
print("[Headless] 开始运行 MI 通讯流程(无界面)")
print(f" MI_IntervalEpoch={_mi_iv}, trial_sec={_trial_sec:.2f}s")
print(f" train_time={train_time:.2f}s, epoch_wait={epoch_wait:.2f}s")
print(f" test_time={test_time:.2f}s, predict_epoch_wait={predict_epoch_wait:.2f}s")
print(f" num_blocks={num_blocks}, num_trials={num_trials}")
print("=" * 50)
try:
while True:
# -------- 个体校准阶段 --------
print("\n[Phase] 个体校准阶段 (paradigm=0)")
client.send_data('rest', 0)
time.sleep(1)
while server.paradigm == 0:
# 左侧 MI 刺激train 0label=1
print(f"\n[Train] 左侧 MI 刺激 (train 0) trained={trained}")
client.send_data('rest', 0)
time.sleep(0.5) # ding 提示后等待
client.send_data('train', 0)
time.sleep(train_time + 0.2) # 等待刺激时间 + epoch 完成时间
trained += 1
client.send_data('rest', 0)
time.sleep(1.0) # 类间休息
# 空闲态样本采集train 1label=2
print(f"\n[Train] 空闲态采集 (train 1) trained={trained}")
client.send_data('train', 1)
time.sleep(train_time + 0.2) # 等待刺激时间 + epoch 完成时间
trained += 1
client.send_data('rest', 0)
time.sleep(1.0) # 类间休息
# 个体校准阶段结束
print("\n[Phase] 个体校准结束,等待模型训练 (paradigm=2) ...")
trained = 0
time.sleep(1)
# 等待模型训练完成 (paradigm=2 -> paradigm=1)
while server.paradigm == 2:
print("[Phase] 等待模型训练完成 ...")
time.sleep(0.5)
# -------- 康复训练阶段 --------
while server.paradigm == 1:
print("\n[Phase] 康复训练阶段 (paradigm=1)")
for block_idx in range(num_blocks):
print(f"\n [Block {block_idx+1}/{num_blocks}]")
time.sleep(10) # 每轮开始前等待
for trial_idx in range(num_trials):
print(f" [Trial {trial_idx+1}/{num_trials}]")
time.sleep(0.5) # ding 提示
server.ChoosenNum = -1
# 开始预测
# MI predict epoch latency = 225包 × 20ms = 4.5s,需额外等待 epoch 完成
client.send_data('predict', 1)
t_start = time.perf_counter()
while time.perf_counter() - t_start < test_time + predict_epoch_wait:
if server.ChoosenNum >= 0:
Num_Total += 1
user_choice.append(server.ChoosenNum)
if server.ChoosenNum == 0:
Num_Success += 1
rest_time = right_rehabilitation
elif server.ChoosenNum == 1:
rest_time = fault_rehabilitation
break
time.sleep(0.02)
trained += 1
client.send_data('rest', 0)
time.sleep(0.5)
time.sleep(rest_time)
server.ChoosenNum = -1
# 训练结束
print("\n[Phase] 康复训练结束")
break # 退出康复训练循环
# 统计结果
overall_accuracy = Num_Success / Num_Total if Num_Total > 0 else 0
print(f"\n[Result] Overall={overall_accuracy:.3f} ({Num_Success}/{Num_Total})")
print(f"[Result] user_choice={user_choice}")
break # 完成一个完整流程后退出
except KeyboardInterrupt:
print("\n[Headless] 用户中断")
finally:
client.send_data('predict', 2) # 关闭系统
client.send_data('saveData', 0)
server.stop()
print("[Headless] 已发送关闭指令,退出。")
if __name__ == '__main__':
run_headless()

View File

@@ -0,0 +1,301 @@
"""
ssmvep_headless.py
无界面版 SSMVEP 范式通讯流程模拟脚本。
复现 ssmvep_main.py 的完整指令序列train 0/1/2, rest, predict, saveData
但不依赖 psychopy 也不打开任何窗口/音频,用 time.sleep 替代帧循环等待。
启动顺序:
1. runDecoder.py
2. datamock.py
3. ssmvep_headless.py
"""
import sys
import os
import json
import time
import threading
import zmq
import numpy as np
from datetime import datetime
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from PubLibrary.InifileHelper import IniRead
personname = 'demo'
session = '01'
DATAMOCK_LABEL_ADDR = 'tcp://127.0.0.1:8101' # datamock 标签命令地址
# ========== ZMQ 结果接收服务 ==========
class ZmqResultServer(threading.Thread):
def __init__(self, port=8088):
threading.Thread.__init__(self)
self.port = port
self.running = True
self.energy = 0
self.paradigm = 0 # 0=个体校准, 1=康复训练, 2=等待模型训练
self.ChoosenNum = -1
self.context = zmq.Context()
self.socket = self.context.socket(zmq.ROUTER)
self.socket.bind(f"tcp://0.0.0.0:{self.port}")
self.daemon = True
self.trial_idx = 0
def run(self):
print(f"[Server] UpperHost_Server listening on {self.port}")
while self.running:
try:
frames = self.socket.recv_multipart(zmq.NOBLOCK)
if len(frames) < 3:
continue
message = json.loads(frames[2].decode('utf-8'))
method = message.get('method')
params = message.get('params')
if method == 'energy':
self.energy = params
elif method == 'paradigm':
self.paradigm = params
print(f"[Server] paradigm -> {params}")
elif method == 'result':
self.ChoosenNum = params
self.trial_idx += 1
print(f"[Server] result={self.ChoosenNum} (trial {self.trial_idx})")
except zmq.Again:
time.sleep(0.005)
except Exception as e:
print(f"[Server] error: {e}")
def stop(self):
self.running = False
self.socket.close()
self.context.term()
# ========== ZMQ 命令发送客户端 ==========
class ZmqCmdClient:
def __init__(self, host, port):
self.host = host
self.port = port
self.context = zmq.Context()
self.socket = self.context.socket(zmq.DEALER)
# PUSH socket 用于向 datamock.py 发送标签命令
self._label_sock = self.context.socket(zmq.PUSH)
self._label_sock.connect(DATAMOCK_LABEL_ADDR)
print(f"[Client] label PUSH connected to {DATAMOCK_LABEL_ADDR}")
def connect(self):
self.socket.connect(f"tcp://{self.host}:{self.port}")
print(f"[Client] connected to {self.host}:{self.port}")
def start_recv_thread(self, result_server):
"""启动后台线程,持续接收 decoder 通过 8099 ROUTER 回发的消息,并更新 result_server 的状态"""
self._result_server = result_server
self._stop_recv = threading.Event()
def _recv_loop():
while not self._stop_recv.is_set():
try:
frames = self.socket.recv_multipart(zmq.NOBLOCK)
# DEALER 收到的格式: [b'', json_bytes]
data_bytes = frames[-1]
message = json.loads(data_bytes.decode('utf-8'))
method = message.get('method')
params = message.get('params')
ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
print(f"[{ts}] [CmdClient] recv: {method}={params}")
if method == 'paradigm':
self._result_server.paradigm = params
print(f"[{ts}] [CmdClient] paradigm updated -> {params}")
elif method == 'result':
self._result_server.ChoosenNum = params
self._result_server.trial_idx += 1
print(f"[{ts}] [CmdClient] result={params} (trial {self._result_server.trial_idx})")
elif method == 'energy':
self._result_server.energy = params
except zmq.Again:
time.sleep(0.005)
except Exception as e:
print(f"[CmdClient recv] error: {e}")
time.sleep(0.01)
self._recv_thread = threading.Thread(target=_recv_loop, daemon=True)
self._recv_thread.start()
print(f"[Client] 后台接收线程已启动(监听 decoder 8099 回发消息)")
def stop_recv_thread(self):
if hasattr(self, '_stop_recv'):
self._stop_recv.set()
def _send_label(self, label_value):
"""向 datamock.py 发送标签命令"""
try:
self._label_sock.send_string(str(label_value), zmq.NOBLOCK)
except Exception as e:
print(f"[Client] label send error: {e}")
def send_data(self, method, params):
msg = {'method': method, 'params': params}
try:
self.socket.send_multipart([b'', json.dumps(msg).encode('utf-8')])
ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
print(f"[{ts}] send_data: {method}={params}")
# 根据 train/predict 命令向 datamock 发送标签
if method == 'train':
if params == 0:
self._send_label(1)
print(f"[Label] train 0 -> datamock label=1")
elif params == 1:
self._send_label(2)
print(f"[Label] train 1 -> datamock label=2")
elif method == 'predict':
self._send_label(99)
print(f"[Label] predict -> datamock label=99")
except Exception as e:
print(f"[Client] send error: {e}")
# ========== 主流程 ==========
def run_headless():
server = ZmqResultServer(port=8088)
server.start()
_dh = str(IniRead('system', 'Decoder_Host'))
_dp = int(IniRead('system', 'Decoder_Port'))
client = ZmqCmdClient(_dh, _dp)
client.connect()
client.start_recv_thread(server) # 启动后台接收线程,监听 decoder 8099 回发的 paradigm/result 消息
time.sleep(1) # 等待连接建立
client.send_data('decoderClass', 'ssmvep')
train_time = 2.5 # 每轮训练刺激时长 (s)
test_time = 2.5 # 每轮测试刺激时长 (s)
right_rehabilitation = float(IniRead('system', 'Right_rehabilitation'))
fault_rehabilitation = float(IniRead('system', 'Fault_rehabilitation'))
rest_time = float(IniRead('system', 'Rest_time'))
num_blocks = int(IniRead('system', 'Num_blocks'))
num_trials = int(IniRead('system', 'Num_trials'))
position = [0, 1]
truePos_seq = position * int(num_trials / len(position))
truePos_seq = np.random.permutation(truePos_seq).tolist()
user_choice = []
os.makedirs('EEGFiles', exist_ok=True)
seq_file_path = f'EEGFiles/pos_seq_{personname}{session}_{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.json'
seq_info = {
'position': position,
'sequence': truePos_seq,
'start_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
}
with open(seq_file_path, 'w', encoding='utf-8') as f:
json.dump(seq_info, f, ensure_ascii=False, indent=2)
trained = 0
Num_Total = 0
Num_Success = 0
print("=" * 50)
print("[Headless] 开始运行 SSMVEP 通讯流程(无界面)")
print(f" num_blocks={num_blocks}, num_trials={num_trials}")
print(f" train_time={train_time}s, test_time={test_time}s")
print("=" * 50)
try:
while True:
# -------- 个体校准阶段 --------
print("\n[Phase] 个体校准阶段 (paradigm=0)")
client.send_data('rest', 0)
time.sleep(1)
# epoch完成需要的额外等待时间train_latency=120包×20ms=2.4s
# 在train_time后需再等epoch_wait秒decoder才能完成epoch采集并取出数据
epoch_wait = 2.4 # 秒与train_latency对应
while server.paradigm == 0:
# 左腿刺激
print(f"\n[Train] 左腿刺激 (train 0) trained={trained}")
client.send_data('train', 0)
time.sleep(train_time + epoch_wait) # 等待刺激时间+epoch完成时间
trained += 1
client.send_data('rest', 0)
time.sleep(max(0, abs(fault_rehabilitation - train_time) - epoch_wait))
# 右腿刺激
print(f"\n[Train] 右腿刺激 (train 1) trained={trained}")
client.send_data('train', 1)
time.sleep(train_time + epoch_wait) # 等待刺激时间+epoch完成时间
trained += 1
client.send_data('rest', 0)
time.sleep(max(0, fault_rehabilitation - epoch_wait))
# 个体校准阶段结束
print("\n[Phase] 个体校准结束,等待 paradigm=1 ...")
trained = 0
time.sleep(1)
# -------- 康复训练阶段 --------
while server.paradigm == 1:
print("\n[Phase] 康复训练阶段 (paradigm=1)")
for block_idx in range(num_blocks):
print(f"\n [Block {block_idx+1}/{num_blocks}]")
time.sleep(10) # 每轮开始前等待
for trial_idx in range(num_trials):
true_position = truePos_seq[trial_idx]
print(f" [Trial {trial_idx+1}/{num_trials}] true_pos={true_position}")
time.sleep(0.5) # 提示 + 叮声
server.ChoosenNum = -1
# 开始测试
# predict epoch latency = 115包×20ms = 2.3s需额外等待epoch完成
predict_epoch_wait = 2.3 # 秒与predict latency=115包对应
client.send_data('predict', 1)
t_start = time.perf_counter()
while time.perf_counter() - t_start < test_time + predict_epoch_wait:
if server.ChoosenNum >= 0:
Num_Total += 1
user_choice.append(server.ChoosenNum)
if server.ChoosenNum in [0, 1]:
Num_Success += 1
rest_time = right_rehabilitation
break
time.sleep(0.02)
trained += 1
client.send_data('rest', 0)
time.sleep(0.5)
time.sleep(rest_time)
server.ChoosenNum = -1
# 训练结束
print("\n[Phase] 康复训练结束")
break # 退出康复训练循环
# 统计结果
overall_accuracy = Num_Success / Num_Total if Num_Total > 0 else 0
expected_seq = truePos_seq * num_blocks
min_len = min(len(user_choice), len(expected_seq))
same_count = sum(1 for a, b in zip(user_choice[:min_len], expected_seq[:min_len]) if a == b)
true_accuracy = same_count / min_len if min_len > 0 else 0
print(f"\n[Result] Overall={overall_accuracy:.3f} ({Num_Success}/{Num_Total})")
print(f"[Result] TrueAcc={true_accuracy:.3f} ({same_count}/{min_len})")
break # 完成一个完整流程后退出
except KeyboardInterrupt:
print("\n[Headless] 用户中断")
finally:
client.send_data('predict', 2) # 关闭系统
client.send_data('saveData', 0)
server.stop()
print("[Headless] 已发送关闭指令,退出。")
if __name__ == '__main__':
run_headless()

View File

@@ -0,0 +1,364 @@
import time
from psychopy import visual, core, logging # import some libraries from PsychoPy
import random
from datetime import datetime
# LAB STREAMING LAYER1
from pylsl import StreamInfo, StreamOutlet
from psychopy import event
import numpy as np
from DecoderDW.Server import TCPServer
from DecoderDW.Client import TCPClient
# import subprocess
# ----------------------
# constants
# size of the window
WINWIDTH = 1920
WINHEIGHT = 1080
REFRESH_RATE = 144
def get_keypress():
keys = event.getKeys()
if keys:
return keys[0]
else:
return None
def shutdown(win,client):
client.send_data('saveData', 0)
client.send_data('predict',2)
win.close()
core.quit()
# end of configuration
# ----------------------
def generate_square_wave(frequency, sampling_rate=REFRESH_RATE, duration=5):
"""
生成方波序列
参数:
frequency (float): 频率Hz
sampling_rate (int): 采样率Hz应与屏幕刷新率一致
duration (float): 时长(秒)
返回:
square_wave (list): 方波序列
"""
# 计算总点数
n_points = int(duration * sampling_rate)
# 生成时间序列
time = np.linspace(0, duration, n_points, endpoint=False)
# 生成正弦波数据
sin_wave = np.sin(2 * np.pi * frequency * time)
# 生成方波数据
square_wave = np.where(sin_wave >= 0, 1, 0)
return square_wave.tolist()
# 启动一个进程,不等待其完成
import os
if __name__ == "__main__":
# ----------------------------------------------------------------------------------
# main window settings
main_win = visual.Window(size=(WINWIDTH, WINHEIGHT), units='height', screen=0, fullscr=False,
gammaErrorPolicy='warn', color=(0.7, 0.7, 0.7))
print('starting 1')
# Set up LabStreamingLayer stream.
info = StreamInfo(name='psychopy_stimuli', type='Markers', channel_count=1, channel_format='string',
source_id='psychopy_stimuli_001')
outlet = StreamOutlet(info) # Broadcast the stream.
imageStim1 = visual.ImageStim(main_win, size=(300, 300), pos=(-600, 300), units='pix', image='UI/figures/xy.jpg')
txtStim1 = visual.TextStim(win=main_win, text='', font='SimHei', height=80, color='black', units='pix', bold=True,
italic=False, pos=(-600, 30))
imageStim2 = visual.ImageStim(main_win, size=(300, 300), pos=(0, 300), units='pix', image='UI/figures/xy.jpg')
txtStim2 = visual.TextStim(win=main_win, text='', font='SimHei', height=80, color='black', units='pix', bold=True,
italic=False, pos=(0, 30))
imageStim3 = visual.ImageStim(main_win, size=(300, 300), pos=(600, 300), units='pix', image='UI/figures/xy.jpg')
txtStim3 = visual.TextStim(win=main_win, text='', font='SimHei', height=80, color='black', units='pix', bold=True,
italic=False, pos=(600, 30))
imageStim4 = visual.ImageStim(main_win, size=(300, 300), pos=(-600, -200), units='pix', image='UI/figures/xy.jpg')
txtStim4 = visual.TextStim(win=main_win, text='', font='SimHei', height=80, color='black', units='pix', bold=True,
italic=False, pos=(-600, -470))
imageStim5 = visual.ImageStim(main_win, size=(300, 300), pos=(0, -200), units='pix', image='UI/figures/xy.jpg')
txtStim5 = visual.TextStim(win=main_win, text='', font='SimHei', height=80, color='black', units='pix', bold=True,
italic=False, pos=(0, -470))
imageStim6 = visual.ImageStim(main_win, size=(300, 300), pos=(600, -200), units='pix', image='UI/figures/xy.jpg')
txtStim6 = visual.TextStim(win=main_win, text='', font='SimHei', height=80, color='black', units='pix', bold=True,
italic=False, pos=(600, -470))
imageStim1red = visual.ImageStim(main_win, size=(300, 300), pos=(-600, 300), units='pix', image='UI/figures/xy_red.jpg')
imageStim2red = visual.ImageStim(main_win, size=(300, 300), pos=(0, 300), units='pix', image='UI/figures/xy_red.jpg')
imageStim3red = visual.ImageStim(main_win, size=(300, 300), pos=(600, 300), units='pix', image='UI/figures/xy_red.jpg')
imageStim4red = visual.ImageStim(main_win, size=(300, 300), pos=(-600, -200), units='pix', image='UI/figures/xy_red.jpg')
imageStim5red = visual.ImageStim(main_win, size=(300, 300), pos=(0, -200), units='pix', image='UI/figures/xy_red.jpg')
imageStim6red = visual.ImageStim(main_win, size=(300, 300), pos=(600, -200), units='pix', image='UI/figures/xy_red.jpg')
frequencies = [25,26,27,28,29,30] #[9,10,11,12,13,14] #[30,31,32,33,34,35] [25,26,27,28,29,30]
# 生成方波数据
square_wave_9 = generate_square_wave(frequencies[0], REFRESH_RATE, 5)
square_wave_11 = generate_square_wave(frequencies[1], REFRESH_RATE, 5)
square_wave_12 = generate_square_wave(frequencies[2], REFRESH_RATE, 5)
square_wave_13 = generate_square_wave(frequencies[3], REFRESH_RATE, 5)
square_wave_14 = generate_square_wave(frequencies[4], REFRESH_RATE, 5)
square_wave_15 = generate_square_wave(frequencies[5], REFRESH_RATE, 5)
# 创建刺激对象列表,便于管理
image_stims = [imageStim1, imageStim2, imageStim3, imageStim4, imageStim5, imageStim6]
txt_stims = [txtStim1, txtStim2, txtStim3, txtStim4, txtStim5, txtStim6]
square_waves = [square_wave_9, square_wave_11, square_wave_12, square_wave_13, square_wave_14, square_wave_15]
time.sleep(2)
# grating.color = 'black'
server = TCPServer()
server.start()
client = TCPClient('127.0.0.1', 8099)
client.connect()
print('Connected decoder_main')
# client.send_data('impedance', 1)
# time.sleep(20)
# client.send_data('impedance', 2)
client.send_data('targetFreqs', frequencies) # 使用frequencies变量确保与刺激频率一致
time.sleep(1)
# 开启全程数据保存到 EEGFiles
client.send_data('saveData',1)
# client.send_data('impedance',1)
# 实验参数
repeats = 3
seq_freq = frequencies * repeats
seq_freq = np.random.permutation(seq_freq).tolist()
num_trials = len(seq_freq) # 总试验次数, 6*6=36
trial_count = 0
# 在线解码精度计算
online_results = [] # 存储每个trial的解码结果
correct_predictions = 0 # 正确预测计数
# 保存序列信息
seq_info = {
'total_trials': num_trials,
'frequencies': frequencies,
'sequence': seq_freq,
'start_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
}
# 保存序列信息到文件
import json
seq_file_path = f'EEGFiles/sequence_{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.json'
with open(seq_file_path, 'a', encoding='utf-8') as f:
json.dump(seq_info, f, ensure_ascii=False, indent=2)
#========================Trials Started======================#
while trial_count < num_trials:
# 从序列中获取当前试验的目标频率
target_freq = seq_freq[trial_count]
target_freq_index = frequencies.index(target_freq)
print(f'Trials {trial_count + 1}/{num_trials} - Target Frequency: {target_freq}Hz (Label: {target_freq_index + 1})')
# Stage 1: Cue Stage
# print('Cue Stage: The target frequency is in Red')
client.send_data('setLabelAndTrialInfo', {
'label': 0,
'trial_info': {
'trial': trial_count + 1,
'phase': 'cue',
'target_freq': target_freq
}
})
for frameN in range(int(1 * REFRESH_RATE)): # 1秒提示
key_press = get_keypress()
if key_press in ['q']:
shutdown(main_win, client)
# 显示所有刺激,目标刺激为红色
for i, stim in enumerate(image_stims):
if i == target_freq_index:
# 目标刺激显示红色
if i == 0:
imageStim1red.draw()
elif i == 1:
imageStim2red.draw()
elif i == 2:
imageStim3red.draw()
elif i == 3:
imageStim4red.draw()
elif i == 4:
imageStim5red.draw()
elif i == 5:
imageStim6red.draw()
else:
# 其他刺激显示正常颜色
stim.draw()
main_win.flip()
# Stage 2: Flanker Stimulus
# print('Flanker Stage: flank all frequencies')
client.send_data('predict', 1)
client.send_data('setLabelAndTrialInfo', {
'label': target_freq_index + 1, # 设置目标频率标签 这里+1是因为0代表不记录数据
'trial_info': {
'trial': trial_count + 1, # trial 从0开始
'phase': 'stimulus',
'target_freq': target_freq
}
})
outlet.push_sample(['S 1'])
for frameN in range(6 * REFRESH_RATE): # 6秒刺激
key_press = get_keypress()
if key_press in ['q']:
shutdown(main_win, client)
# 所有频率按照方波闪烁
if square_wave_9[frameN % len(square_wave_9)] == 1:
imageStim1.draw()
if square_wave_11[frameN % len(square_wave_11)] == 1:
imageStim2.draw()
if square_wave_12[frameN % len(square_wave_12)] == 1:
imageStim3.draw()
if square_wave_13[frameN % len(square_wave_13)] == 1:
imageStim4.draw()
if square_wave_14[frameN % len(square_wave_14)] == 1:
imageStim5.draw()
if square_wave_15[frameN % len(square_wave_15)] == 1:
imageStim6.draw()
main_win.flip()
if server.ChoosenNum != -1:
break
# 记录在线解码结果
predicted_freq_index = server.ChoosenNum # 解码结果
predicted_freq = frequencies[predicted_freq_index] if predicted_freq_index != -1 else -1
# 判断解码是否正确
is_correct = (predicted_freq_index == target_freq_index) if predicted_freq_index != -1 else False
if is_correct:
correct_predictions += 1
# 记录trial结果
trial_result = {
'trial': trial_count + 1,
'target_freq': target_freq,
'target_freq_index': target_freq_index,
'predicted_freq': predicted_freq,
'predicted_freq_index': predicted_freq_index,
'is_correct': is_correct,
'status': 'Success' if predicted_freq_index != -1 else 'Failed'
}
online_results.append(trial_result)
# 打印当前trial结果
status_symbol = "" if is_correct else ""
if predicted_freq_index == -1:
print(f'Trial {trial_count + 1}: 目标{target_freq}Hz -> 解码失败 - {status_symbol}')
else:
print(f'Trial {trial_count + 1}: 目标{target_freq}Hz -> 预测{predicted_freq}Hz - {status_symbol}')
# Stage 3: Decoding Feedback
outlet.push_sample(['S 2'])
client.send_data('setLabelAndTrialInfo', {
'label': 0, # 反馈阶段标签为0
'trial_info': {
'trial': trial_count + 1,
'phase': 'feedback',
'target_freq': target_freq
}
})
# print('反馈阶段: 显示解码结果')
for frameN in range(1 * REFRESH_RATE): # 1秒反馈
key_press = get_keypress()
if key_press in ['q']:
shutdown(main_win, client)
# 显示所有刺激但不闪烁
for stim in image_stims:
stim.draw()
# 显示解码结果
if server.ChoosenNum == 0:
txtStim1.draw()
elif server.ChoosenNum == 1:
txtStim2.draw()
elif server.ChoosenNum == 2:
txtStim3.draw()
elif server.ChoosenNum == 3:
txtStim4.draw()
elif server.ChoosenNum == 4:
txtStim5.draw()
elif server.ChoosenNum == 5:
txtStim6.draw()
main_win.flip()
server.ChoosenNum = -1
trial_count += 1
# 计算总体在线解码精度
total_trials = len(online_results)
successful_trials = len([r for r in online_results if r['status'] == 'Success'])
failed_trials = len([r for r in online_results if r['status'] == 'Failed'])
overall_accuracy = correct_predictions / total_trials if total_trials > 0 else 0
# Print Accuracy
print(f"Total Accuracy: {overall_accuracy:.3f} ({correct_predictions}/{total_trials})")
# 按频率分析准确率
print(f"\n=== 按频率分析准确率 ===")
freq_accuracy = {}
for result in online_results:
freq = result['target_freq']
if freq not in freq_accuracy:
freq_accuracy[freq] = {'correct': 0, 'total': 0, 'failed': 0}
freq_accuracy[freq]['total'] += 1
if result['status'] == 'Failed':
freq_accuracy[freq]['failed'] += 1
elif result['is_correct']:
freq_accuracy[freq]['correct'] += 1
print(f"{'频率':<8} {'准确率':<8} {'正确/总数':<10} {'失败数':<8}")
print("-" * 40)
for freq in sorted(freq_accuracy.keys()):
stats = freq_accuracy[freq]
accuracy = stats['correct'] / stats['total'] if stats['total'] > 0 else 0
print(f"{freq}Hz{'':<4} {accuracy:.3f}{'':<4} {stats['correct']}/{stats['total']}{'':<6} {stats['failed']}")
# 保存在线解码结果到文件
online_results_file = f'EEGFiles/online_results_{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.json'
online_summary = {
'total_trials': total_trials,
'successful_trials': successful_trials,
'failed_trials': failed_trials,
'correct_predictions': correct_predictions,
'overall_accuracy': overall_accuracy,
# 'freq_accuracy': freq_accuracy,
'trial_results': online_results,
# 'end_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
}
with open(online_results_file, 'w', encoding='utf-8') as f:
json.dump(online_summary, f, ensure_ascii=False, indent=2)
client.send_data('predict',2) # 关闭系统
main_win.close()

304
verify_datamock.py Normal file
View File

@@ -0,0 +1,304 @@
"""
datamock 验证脚本(模拟算法端)
作为 ZMQ ROUTER 监听 8100 端口,等待 datamock.py 连接并验证数据流
运行顺序:
第一步: python verify_datamock.py (先启动,监听 8100)
第二步: python datamock.py (后启动,连接 8100)
"""
import zmq
import numpy as np
import time
import sys
import matplotlib
matplotlib.use('TkAgg')
# 在导入 pyplot 之前确保 Tkinter 正确初始化
try:
import tkinter as tk
root = tk.Tk()
root.withdraw() # 隐藏主窗口,我们只需要它的事件循环
except Exception as e:
print(f"[WARN] Tkinter 初始化警告: {e}")
import matplotlib.pyplot as plt
from datetime import datetime
# ===== 可视化参数 =====
PLOT_WINDOW_SEC = 2.0 # 滑动窗口时长(秒)
PLOT_CHANNELS = [0, 1, 2, 3] # 要显示的 EEG 通道索引
SERVER_ADDR = 'tcp://127.0.0.1:8100'
FS = 250
N_SAMPLES_PER_PKT = 5
N_CHAN = 66
EEG_FREQ = 10
EEG_AMP = 100.0 # EEG 幅值 100μV峰值
EEG_AMP_MEAN = EEG_AMP * 2 / np.pi # 正弦波 |mean| ≈ 63.7μV
EEG_AMP_TOLERANCE = 1.5 # 幅值容差倍数
LABEL_INTERVAL = 5
FFT_SAMPLES = 250 # 做一次 FFT 需要的采样点数1s数据
EXPECTED_BYTES = N_SAMPLES_PER_PKT * N_CHAN * 4 # 1320 bytes (5*66*4)
def validate_fft(samples):
"""对 Ch0 数据做 FFT返回峰值频率"""
freqs = np.fft.rfftfreq(FFT_SAMPLES, d=1 / FS)
fft_mag = np.abs(np.fft.rfft(samples))
peak_idx = np.argmax(fft_mag[1:]) + 1 # 跳过 DC
return freqs[peak_idx], fft_mag, freqs
def main():
ctx = zmq.Context()
sock = ctx.socket(zmq.ROUTER)
sock.bind(SERVER_ADDR)
print(f"[{datetime.now().strftime('%H:%M:%S')}] ZMQ ROUTER 绑定 {SERVER_ADDR},等待 datamock.py 连接...\n")
# ===== 初始化交互式绘图 =====
plt.ion() # 开启交互模式
fig = plt.figure(figsize=(14, 10))
fig.suptitle('EEG Data Monitor (Real-time)', fontsize=14)
# 使用 GridSpec 进行布局
from matplotlib.gridspec import GridSpec
gs = GridSpec(len(PLOT_CHANNELS) + 2, 1, figure=fig, hspace=0.3)
axes = []
lines_eeg = []
for i, ch in enumerate(PLOT_CHANNELS):
ax = fig.add_subplot(gs[i])
axes.append(ax)
ax.set_ylabel(f'Ch{ch} (μV)', fontsize=8)
ax.grid(True, alpha=0.3)
ax.set_ylim(-150, 150)
line, = ax.plot([], [], lw=0.8)
lines_eeg.append(line)
ax.set_title(f'EEG Channel {ch}', fontsize=9)
# 标签通道子图 (Ch64 - 标签值)
ax_label = fig.add_subplot(gs[len(PLOT_CHANNELS)])
axes.append(ax_label)
ax_label.set_ylabel('Label Value', fontsize=8)
ax_label.grid(True, alpha=0.3)
ax_label.set_ylim(-0.5, 2.5)
line_label, = ax_label.plot([], [], 'ro-', lw=1.5, markersize=4)
line_label_data = line_label
ax_label.set_title('Ch64 - Label Value', fontsize=9)
# Ch65 标签序号子图
ax_seq = fig.add_subplot(gs[len(PLOT_CHANNELS) + 1])
axes.append(ax_seq)
ax_seq.set_ylabel('Label Seq', fontsize=8)
ax_seq.set_xlabel('Time (samples)', fontsize=8)
ax_seq.grid(True, alpha=0.3)
ax_seq.set_ylim(-0.5, 10)
line_seq, = ax_seq.plot([], [], 'gs-', lw=1.5, markersize=4)
line_seq_data = line_seq
ax_seq.set_title('Ch65 - Label Sequence', fontsize=9)
plt.tight_layout()
# ===== 状态 =====
global_idx = 0 # 全局采样点索引
label_events = [] # 捕获的标签事件
start_time = None
fft_done = False
fft_buffer = [] # 暂存前 250 点做 FFT
ch64_zero_ok = True # 验证 Ch64 非标签采样点均为 0
ch65_zero_ok = True # 验证 Ch65 非标签采样点均为 0
label_pos_ok_all = True # 验证标签均在包内索引 4
# ===== 数据缓冲区 =====
max_samples = int(FS * PLOT_WINDOW_SEC)
eeg_buffer = {ch: np.zeros(max_samples) for ch in PLOT_CHANNELS}
label_buffer = np.zeros(max_samples)
seq_buffer = np.zeros(max_samples)
time_axis = np.arange(max_samples)
# ZMQ 收发统计
recv_count = 0
try:
# 首次 pause 用于显示窗口
plt.pause(0.5)
print(f"[INFO] 交互窗口已显示,如未看到请检查任务栏")
while True:
# ROUTER recv: prepended 一个 identity 帧
# datamock 发送 3帧 [b'datamock', b'', data_bytes]
# ROUTER 接收后变成 4帧 [router_identity, b'datamock', b'', data_bytes]
frames = sock.recv_multipart()
recv_count += 1
now = time.time()
if start_time is None:
start_time = now
# 帧格式: [router_identity, b'datamock', b'', data_bytes]
router_id = frames[0] # ROUTER 添加的身份帧
identity = frames[1] # 发送端的 identity
_empty = frames[2] # 空帧
raw_data = frames[3] # 实际数据字节
# 数据长度校验
if len(raw_data) != EXPECTED_BYTES:
print(f"[ERROR] 数据长度错误: 期望{EXPECTED_BYTES}字节, 实际{len(raw_data)}字节")
continue
# 解析为 [5, 66] float32 数组
packet = np.frombuffer(raw_data, dtype=np.float32).reshape(N_SAMPLES_PER_PKT, N_CHAN)
elapsed = now - start_time
# ===== 验证 1: 数据形状 =====
if recv_count == 1:
shape_ok = packet.shape == (N_SAMPLES_PER_PKT, N_CHAN)
print(f"[{'' if shape_ok else ''}] 数据形状: {packet.shape} "
f"(期望 [{N_SAMPLES_PER_PKT}, {N_CHAN}])")
if not shape_ok:
print(f" ✗ 形状不匹配,退出")
break
# ===== 验证 2: EEG 幅值(首包) =====
if recv_count == 1:
eeg = packet[:, :64]
amp_mean = np.mean(np.abs(eeg))
amp_ok = amp_mean <= EEG_AMP_MEAN * EEG_AMP_TOLERANCE
print(f"[{'' if amp_ok else ''}] EEG 幅值: 均值={amp_mean:.2f}μV "
f"(期望 ~{EEG_AMP_MEAN:.2f}μV峰值 ~{EEG_AMP:.2f}μV)")
if not amp_ok:
print(f" ✗ 幅值超出容差范围")
# ===== 验证 3: EEG 频率(首秒数据收集满后做 FFT =====
fft_buffer.append(packet[:, 0].copy()) # 收集 Ch0
if not fft_done and len(fft_buffer) * N_SAMPLES_PER_PKT >= FFT_SAMPLES:
# 凑够 250 点,做 FFT
all_ch0 = np.concatenate(fft_buffer)[:FFT_SAMPLES]
peak_freq, fft_mag, freqs = validate_fft(all_ch0)
freq_ok = abs(peak_freq - EEG_FREQ) < 1.0
print(f"[{'' if freq_ok else ''}] EEG 频率: 峰值={peak_freq:.1f}Hz "
f"(期望 ~{EEG_FREQ}Hz)")
print(f" FFT 幅度谱前 5 峰值:")
top5 = np.argsort(fft_mag[1:])[-5:][::-1] + 1
for rank, idx in enumerate(top5):
print(f" {rank+1}. {freqs[idx]:.1f}Hz 幅度={fft_mag[idx]:.1f}")
print()
fft_done = True
# ===== 验证 4: 标签通道Ch64/Ch65 =====
ch64 = packet[:, 64]
ch65 = packet[:, 65]
ch64_nonzero = np.where(ch64 != 0)[0]
ch65_nonzero = np.where(ch65 != 0)[0]
# 检查非标签采样点是否全为 0
ch64_zeros = np.all(ch64[:4] == 0)
ch65_zeros = np.all(ch65[:4] == 0)
ch64_zero_ok = ch64_zero_ok and ch64_zeros
ch65_zero_ok = ch65_zero_ok and ch65_zeros
if len(ch64_nonzero) > 0:
pos_in_pkt = int(ch64_nonzero[0])
label_val = int(ch64[pos_in_pkt])
label_seq = int(ch65[pos_in_pkt])
pos_ok = (len(ch64_nonzero) == 1 and pos_in_pkt == 4)
label_pos_ok_all = label_pos_ok_all and pos_ok
elapsed_since_start = now - start_time
print(f"[✓] 标签触发 @ {elapsed_since_start:.1f}s "
f"(global_idx={global_idx}{recv_count})")
print(f" Ch64 标签值: {label_val} Ch65 序号: {label_seq}")
print(f" 包内位置: 采样点 {pos_in_pkt}/4 "
f"({'' if pos_ok else '✗ 期望 4'}) "
f"其余采样点 Ch64=0: {'' if ch64_zeros else ''} "
f"Ch65=0: {'' if ch65_zeros else ''}")
print()
label_events.append({
'time': elapsed_since_start,
'label': label_val,
'seq': label_seq
})
global_idx += N_SAMPLES_PER_PKT
# ===== 更新绘图缓冲区 =====
for ch_idx, ch in enumerate(PLOT_CHANNELS):
eeg_buffer[ch] = np.roll(eeg_buffer[ch], -N_SAMPLES_PER_PKT)
eeg_buffer[ch][-N_SAMPLES_PER_PKT:] = packet[:, ch]
label_buffer = np.roll(label_buffer, -N_SAMPLES_PER_PKT)
label_buffer[-N_SAMPLES_PER_PKT:] = packet[:, 64]
seq_buffer = np.roll(seq_buffer, -N_SAMPLES_PER_PKT)
seq_buffer[-N_SAMPLES_PER_PKT:] = packet[:, 65]
# ===== 实时更新绘图 =====
for i, ch in enumerate(PLOT_CHANNELS):
lines_eeg[i].set_data(time_axis, eeg_buffer[ch]) # 数据已是 μV 单位
line_label_data.set_data(time_axis, label_buffer)
line_seq_data.set_data(time_axis, seq_buffer)
# 设置 x 轴范围
for ax in axes:
ax.set_xlim(0, max_samples)
# 刷新图形(交互模式)
fig.canvas.draw_idle()
plt.pause(0.001)
except KeyboardInterrupt:
print("\n" + "=" * 55)
print(" 验证结果汇总")
print("=" * 55)
print(f" 运行时长: {time.time() - start_time:.1f}s")
print(f" 收到包数: {recv_count}")
print(f" FFT 验证: {'✓ 已完成' if fft_done else '✗ 未完成时长不足1s'}")
print(f" 非标签采样点 Ch64=0: {'' if ch64_zero_ok else ''}")
print(f" 非标签采样点 Ch65=0: {'' if ch65_zero_ok else ''}")
print(f" 标签均在包内位置4: {'' if label_pos_ok_all else ''}")
if label_events:
print(f"\n 共捕获 {len(label_events)} 次标签事件:")
for i, ev in enumerate(label_events):
print(f" {i+1}. t={ev['time']:.1f}s label={ev['label']} 序号={ev['seq']}")
# 标签间隔
print(f"\n 标签间隔验证 (期望 ~{LABEL_INTERVAL}s):")
for i in range(1, len(label_events)):
dt = label_events[i]['time'] - label_events[i-1]['time']
ok = abs(dt - LABEL_INTERVAL) < 0.1
print(f" {i}->{i+1}: {dt:.2f}s {'' if ok else ''}")
# 标签交替
labels = [e['label'] for e in label_events]
alt_ok = all(labels[i] != labels[i+1] for i in range(len(labels) - 1))
print(f"\n 标签交替: {labels} {'✓ 交替正确' if alt_ok else '✗ 交替错误'}")
# 序号
label1_seqs = [e['seq'] for e in label_events if e['label'] == 1]
label2_seqs = [e['seq'] for e in label_events if e['label'] == 2]
s1_ok = label1_seqs == list(range(1, len(label1_seqs) + 1))
s2_ok = label2_seqs == list(range(1, len(label2_seqs) + 1))
print(f" label=1 序号: {label1_seqs} {'' if s1_ok else ''}")
print(f" label=2 序号: {label2_seqs} {'' if s2_ok else ''}")
else:
print(f"\n 未捕获标签事件(运行时长不足 {LABEL_INTERVAL}s")
print("=" * 55)
finally:
sock.close()
ctx.term()
plt.ioff()
plt.close('all')
try:
root.destroy()
except:
pass
if __name__ == '__main__':
main()