@@ -26,6 +26,8 @@ from SSVEP.dwfbcca import FbccaDw
from collections import deque
from Zmq . filterProcess import SlidingFilter
save_train_data = int ( IniRead ( ' system ' , ' save_train_data ' , 0 ) )
def get_root_path ( ) :
"""
Nuitka 打包专用:获取程序根目录(.py 或 .exe 所在目录)
@@ -209,11 +211,10 @@ class Decoder_main(threading.Thread):
elif self . decoder_class == ' mi ' :
self . decoder_MI ( )
else :
if self . zmqServer . open_Impedance == False : # 非阻抗检测状态
if self . zmqServer . paradigmBuffer . GetDataLenCount ( ) < 2 5:
time . sleep ( 0.005 )
continue ;
self . zmqServer . paradigmBuffer . getData ( 25 )
if self . zmqServer . paradigmBuffer . GetDataLenCount ( ) < 25 :
time . sleep ( 0.00 5)
continue ;
self . zmqServer . paradigmBuffer . getData ( 25 )
except Exception as e :
algo_log ( f " Decoder Loop Error: { e } " )
time . sleep ( 0.1 ) # Prevent CPU spin if error is persistent
@@ -223,31 +224,32 @@ class Decoder_main(threading.Thread):
self . zmqServer . StartDecode = False
self . decodingSteps = 1
self . zmqServer . paradigmBuffer . resetAllPara ( )
print ( ' 启动预测 ' )
algo_log ( ' 启动SSVEP预测 ' , level = " DEBUG " )
if self . zmqServer . paradigmBuffer . GetDataLenCount ( ) < 50 :
time . sleep ( 0.005 )
return
if self . zmqServer . open_Impedance : # 阻抗检测状态不解码
return
data = self . zmqServer . paradigmBuffer . getDataViaSSVEP ( 50 )
algo_log ( f " SSVEP取出的: { data . shape } , data = { data [ : 20 ] } " , level = " DEBUG " )
data = data [ : self . n_chan , : ]
if self . decodingSteps == 1 and hasattr ( self , ' dw ' ) : # 开始预热
self . dw . onlineInit ( ) # 刺激闪烁的第1s重置 --在线数据采集时
self . dw . warmFilter ( data ) # 预热
self . decodingSteps = 2
print ( ' 预热数据完成。开始预测 ' )
algo_log ( ' SSVEP 预热数据完成。开始预测' , level = " DEBUG " )
return
if self . decodingSteps == 2 and hasattr ( self , ' dw ' ) : # 解码中
choosenNum = self . dw . fbccaDWMW ( data , self . referenceData , self . DW_cost_tv , self . calculateCount )
self . calculateCount + = 1
if choosenNum != - 1 and self . is_valid_signal ( data ) :
self . decodingSteps = 3
print ( ' 预测结果: ' + str ( choosenNum ) + ' ,计算次数: ' + str ( self . calculateCount ) )
algo_log ( ' SSVEP 预测结果:' + str ( choosenNum ) + ' ,计算次数: ' + str ( self . calculateCount ) , level = " DEBUG " )
self . calculateCount = 0
if self . decodingSteps == 3 : # 发送解码后的信息
self . zmqServer . broadcast_message ( ' result ' , int ( choosenNum ) )
self . decodingSteps = 0
print ( ' 发送给界面完成。 ' )
algo_log ( ' SSVEP 发送给界面完成。' , level = " DEBUG " )
def decoder_SSMVEP ( self ) :
''' 模型训练 '''
@@ -255,34 +257,28 @@ class Decoder_main(threading.Thread):
self . trainLabel . count ( i ) > = self . single_train for i in range ( len ( self . list_freqs ) ) ) : # 模型尚未训练完成
self . trainData = np . array ( self . trainData )
self . trainLabel = np . array ( self . trainLabel )
print ( np . shape ( self . trainData ) , ( self . trainLabel ) )
# 保存多个数组到文件
# np.savez('20250520_yy.npz', array1=self.trainData, array2=self.trainLabel )
# self.decoder = self.fbtdca.fit(self.trainData, self.trainLabel, Yf=self.Yf)
algo_log ( f " 开始SSMVEP模型训练, 数据形状: { np . shape ( self . trainData ) } ,标签形状: { self . trainLabel . shape } " , level = " DEBUG " )
if save_train_data == 1 :
now_str = datetime . now ( ) . strftime ( " % Y % m %d _ % H % M % S " )
save_path = f " { now_str } .npz "
np . savez ( save_path , array1 = self . trainData , array2 = self . trainLabel )
self . decoder = self . tdca . fit ( self . trainData , self . trainLabel , Yf = self . Yf )
now = datetime . now ( )
formatted_time = now . strftime ( ' % H: % M: % S. %f ' ) [ : - 3 ]
print ( ' 模型训练完成 ' , formatted_time )
algo_log ( f " SSMVEP模型训练完成, 时间: { formatted_time } " , level = " DEBUG " )
self . load_model = True
self . zmqServer . broadcast_message ( ' paradigm ' , 1 )
''' 训练阶段采集数据 '''
if self . zmqServer . state_mode == ' train ' : # 训练状态
if self . zmqServer . epoch_finished and self . zmqServer . paradigmBuffer . GetDataLenCount ( ) > = \
self . train_epoch [ 1 ] + self . zmqServer . event_inner_idx :
self . currentLabel = self . zmqServer . currentLabel
print ( ' 训练队列数据: ' , self . zmqServer . paradigmBuffer . GetDataLenCount ( ) )
trainTrial = self . zmqServer . paradigmBuffer . get_SSMVEPData ( ) # 取出所有数据
print ( ' 取出的: ' , trainTrial . shape , ' event: ' , trainTrial [ - 2 , self . zmqServer . event_inner_idx ] )
algo_log ( f " 取出的: { trainTrial . shape } , event: { trainTrial [ - 2 , self . zmqServer . event_inner_idx ] } " , level = " DEBUG " )
trainTrial = self . preprocess ( trainTrial [ : self . n_chan , : ] ) # 预处理
trainTrial = trainTrial [ : , self . zmqServer . event_inner_idx + self . train_epoch [
0 ] : self . zmqServer . event_inner_idx + self . train_epoch [ 1 ] ]
print ( ' trial: ' , self . zmqServer . event_inner_idx , self . train_epoch [ 0 ] , self . train_epoch [ 1 ] )
if trainTrial . shape [ 1 ] == ( self . train_epoch [ 1 ] - self . train_epoch [ 0 ] ) and isinstance (
self . trainLabel , list ) \
and self . trainLabel . count ( self . currentLabel ) < self . single_train :
@@ -301,15 +297,14 @@ class Decoder_main(threading.Thread):
self . zmqServer . StartDecode = False
now = datetime . now ( )
formatted_time = now . strftime ( ' % H: % M: % S. %f ' ) [ : - 3 ]
print ( ' 启动预测 ' , formatted_time )
algo_log ( f " SSMVEP模型启动预测 { formatted_time } " , level = " DEBUG " )
if self . zmqServer . epoch_finished == False or self . zmqServer . paradigmBuffer . GetDataLenCount ( ) < \
self . interval_epoch [ 1 ] \
+ self . zmqServer . event_inner_idx :
time . sleep ( 0.0001 )
return
data = self . zmqServer . paradigmBuffer . get_SSMVEPData ( ) # 读取全部数据
print ( ' 取出的: ' , data . shape , ' event: ' , data [ - 2 , self . zmqServer . event_inner_idx ] )
algo_log ( f " 取出的: { data . shape } , event: { data [ - 2 , self . zmqServer . event_inner_idx ] } " , level = " DEBUG " )
data = self . preprocess ( data [ : self . n_chan , : ] ) # 预处理
data = data [ : ,
self . zmqServer . event_inner_idx + self . interval_epoch [
@@ -320,16 +315,14 @@ class Decoder_main(threading.Thread):
choosenNum , features_2 = self . decoder . predict ( pad_eeg_test )
if isinstance ( choosenNum , np . ndarray ) :
choosenNum = choosenNum [ 0 ]
print ( ' 结果: ' , choosenNum , ' rho: ' , sorted ( features_2 [ 0 ] ) ,
sorted ( features_2 [ 0 ] ) [ - 1 ] - sorted ( features_2 [ 0 ] ) [ - 2 ] )
algo_log ( f " 结果: { choosenNum } , rho: { sorted ( features_2 [ 0 ] ) [ - 1 ] - sorted ( features_2 [ 0 ] ) [ - 2 ] } " , level = " DEBUG " )
self . zmqServer . broadcast_message ( ' result ' , int ( choosenNum ) )
print ( ' 发送给界面完成。 ' )
algo_log ( " SSMVEP发送给界面完成。 " , level = " DEBUG " )
else : # 休息状态
if self . zmqServer . open_Impedance == False : # 非阻抗检测状态
if self . zmqServer . paradigmBuffer . GetDataLenCount ( ) < 2 5:
time . sleep ( 0.005 )
return
self . zmqServer . paradigmBuffer . getData ( 25 )
if self . zmqServer . paradigmBuffer . GetDataLenCount ( ) < 25 :
time . sleep ( 0.00 5)
return
self . zmqServer . paradigmBuffer . getData ( 25 )
def decoder_MI ( self ) :
''' 模型训练 '''
@@ -339,7 +332,11 @@ class Decoder_main(threading.Thread):
self . train_started = True
self . trainData = np . array ( self . trainData )
self . trainLabel = np . array ( self . trainLabel ) + 1
# print('训练集:',np.shape(self.trainData), (self.trainLabel) )
algo_log ( f " MI开始训练, 训练集: { np . shape ( self . trainData ) } , 标签shape: { np . shape ( self . trainLabel ) } " , level = " DEBUG " )
if save_train_data == 1 :
now_str = datetime . now ( ) . strftime ( " % Y % m %d _ % H % M % S " )
save_path = f " { now_str } .npz "
np . savez ( save_path , array1 = self . trainData , array2 = self . trainLabel )
p = mp . Process ( target = onlineTrain , args = ( self . mp_data_queue , self . mp_result_queue ) ) # 开启子进程,训练模型
p . start ( )
self . mp_data_queue . put ( { ' data ' : self . trainData , ' label ' : self . trainLabel , ' modelPath ' : self . modelPath ,
@@ -350,7 +347,7 @@ class Decoder_main(threading.Thread):
try :
result = self . mp_result_queue . get_nowait ( )
if result [ ' status ' ] == ' success ' :
print ( " 模型训练完成,加载新模型 " )
algo_log ( " MI 模型训练完成,加载新模型" , level = " DEBUG " )
# 调用模型
self . model = torch . load ( self . modelPath , weights_only = False )
self . model . eval ( )
@@ -363,45 +360,42 @@ class Decoder_main(threading.Thread):
self . load_model = True
self . zmqServer . broadcast_message ( ' paradigm ' , 1 ) # 模型调用完毕,通知上位机
else :
print ( " 训练失败: " , result [ ' msg ' ] )
algo_log ( " MI 训练失败: " + result [ ' msg ' ] , level = " DEBUG " )
except Empty :
pass # 还没完成
except Exception as e :
print ( ' 模型调用失败: ' , e )
algo_log ( " MI模型训练失败: " + str ( e ) , level = " DEBUG " )
''' 训练阶段采集数据 '''
if self . zmqServer . state_mode == ' train ' and self . train_started == False : # 训练状态
if self . zmqServer . StartTrain :
self . currentLabel = self . zmqServer . currentLabel
self . zmqServer . StartTrain = False
if self . zmqServer . epoch_finished == False or self . zmqServer . paradigmBuffer . G etDataLenCount ( ) < \
self . int erval_epoch [ 1 ] \
+ self . zmqServer . event_inner_idx :
if self . zmqServer . epoch_finished and self . zmqServer . paradigmBuffer . GetDataLenCount ( ) > = \
self . interval_epoch [ 1 ] + self . zmqServer . event_inner_idx :
algo_log ( f " 训练队列数据: { self . zmqServer . paradigmBuffer . GetDataLenCount ( ) } " , level = " DEBUG " )
originalTrial = self . zmqServer . paradigmBuffer . g et_MI Data( ) # 取出MI导联数据
algo_log ( f " 取出的: { originalTrial . shape } ,event: { originalTrial [ - 2 , self . zmqServer . event_ inn er_idx ] } " , level = " DEBUG " )
trainTrial = self . preprocess ( originalTrial [ : self . n_chan , : ] ) # 预处理
trainTrial = trainTrial [ : , self . zmqServer . event_inner_idx + self . interval_epoch [
0 ] : self . zmqServer . event_inner_idx + self . interval_epoch [ 1 ] ]
algo_log ( f " trial: { self . zmqServer . event_inner_idx } , { self . interval_epoch [ 0 ] } , { self . interval_epoch [ 1 ] } " , level = " DEBUG " )
if trainTrial . shape [ 1 ] == ( self . interval_epoch [ 1 ] - self . interval_epoch [ 0 ] ) and isinstance ( self . trainLabel ,
list ) \
and self . trainLabel . count ( self . currentLabel ) < self . single_train :
self . trainData . append ( trainTrial )
self . trainLabel . append ( self . currentLabel )
algo_log ( f " 训练集: { np . shape ( self . trainData ) } " , level = " DEBUG " )
self . plotData . append ( originalTrial [ : self . n_chan , self . zmqServer . event_inner_idx + self . interval_epoch [
0 ] : self . zmqServer . event_inner_idx + self . interval_epoch [ 1 ] ] )
self . plotLabel . append ( self . currentLabel )
else :
time . sleep ( 0.0001 )
return
print ( ' 训练队列数据: ' , self . zmqServer . paradigmBuffer . GetDataLenCount ( ) )
originalTrial = self . zmqServer . paradigmBuffer . get_MIData ( ) # 取出MI导联数据
print ( ' 取出的: ' , originalTrial . shape , ' event: ' , originalTrial [ - 2 , self . zmqServer . event_inner_idx ] )
trainTrial = self . preprocess ( originalTrial [ : self . n_chan , : ] ) # 预处理
trainTrial = trainTrial [ : , self . zmqServer . event_inner_idx + self . interval_epoch [
0 ] : self . zmqServer . event_inner_idx + self . interval_epoch [ 1 ] ]
print ( ' trial: ' , self . zmqServer . event_inner_idx , self . interval_epoch [ 0 ] , self . interval_epoch [ 1 ] )
if trainTrial . shape [ 1 ] == ( self . interval_epoch [ 1 ] - self . interval_epoch [ 0 ] ) and isinstance ( self . trainLabel ,
list ) \
and self . trainLabel . count ( self . currentLabel ) < self . single_train :
self . trainData . append ( trainTrial )
self . trainLabel . append ( self . currentLabel )
print ( ' 训练集: ' , np . shape ( self . trainData ) )
self . plotData . append ( originalTrial [ : self . n_chan , self . zmqServer . event_inner_idx + self . interval_epoch [
0 ] : self . zmqServer . event_inner_idx + self . interval_epoch [ 1 ] ] )
self . plotLabel . append ( self . currentLabel )
elif self . zmqServer . state_mode == ' predict ' and self . load_model == True : # 测试状态
if self . zmqServer . StartDecode :
self . zmqServer . StartDecode = False
now = datetime . now ( )
formatted_time = now . strftime ( ' % H: % M: % S. %f ' ) [ : - 3 ]
print ( ' 启动预测 ' , formatted_time )
algo_log ( f " MI启动预测 { formatted_time } " , level = " DEBUG " )
if self . zmqServer . epoch_finished == False or self . zmqServer . paradigmBuffer . GetDataLenCount ( ) < \
self . interval_epoch [ 1 ] \
@@ -409,7 +403,7 @@ class Decoder_main(threading.Thread):
time . sleep ( 0.0001 )
return
originalData = self . zmqServer . paradigmBuffer . get_MIData ( ) # 读取全部数据
print ( ' 取出的: ' , originalData . shape , ' event: ' , originalData [ - 2 , self . zmqServer . event_inner_idx ] )
algo_log ( f " 取出的: { originalData . shape } , event: { originalData [ - 2 , self . zmqServer . event_inner_idx ] } " , level = " DEBUG " )
start = time . time ( )
data = self . preprocess ( originalData [ : self . n_chan , : ] ) # 预处理
data = data [ : ,
@@ -426,16 +420,15 @@ class Decoder_main(threading.Thread):
Cls = self . model ( test_data )
y_pred = torch . max ( Cls , 1 ) [ 1 ]
self . plotLabel . append ( int ( y_pred . item ( ) ) )
print ( ' 运动意图识别: ' , y_pred )
algo_log ( f " MI 运动意图识别: { y_pred } " )
self . zmqServer . broadcast_message ( ' paradigm ' , int ( y_pred . item ( ) ) )
end = time . time ( )
print ( f ' 发送给界面完成,耗时 { end - start : .3f } s。 ' )
else : # 休息状态
if self . zmqServer . open_Impedance == False : # 非阻抗检测状态
if self . zmqServer . paradigmBuffer . GetDataLenCount ( ) < 2 5:
time . sleep ( 0.005 )
return
self . zmqServer . paradigmBuffer . getData ( 25 )
if self . zmqServer . paradigmBuffer . GetDataLenCount ( ) < 25 :
time . sleep ( 0.00 5)
return
self . zmqServer . paradigmBuffer . getData ( 25 )
# def decoder_concentration(self):
# if self.zmqServer.state_mode == 'predict':