update
This commit is contained in:
127
Decoder.py
127
Decoder.py
@@ -26,6 +26,8 @@ from SSVEP.dwfbcca import FbccaDw
|
||||
from collections import deque
|
||||
from Zmq.filterProcess import SlidingFilter
|
||||
|
||||
save_train_data = int(IniRead('system', 'save_train_data', 0))
|
||||
|
||||
def get_root_path():
|
||||
"""
|
||||
Nuitka 打包专用:获取程序根目录(.py 或 .exe 所在目录)
|
||||
@@ -209,11 +211,10 @@ class Decoder_main(threading.Thread):
|
||||
elif self.decoder_class == 'mi':
|
||||
self.decoder_MI()
|
||||
else:
|
||||
if self.zmqServer.open_Impedance == False: # 非阻抗检测状态
|
||||
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
||||
time.sleep(0.005)
|
||||
continue;
|
||||
self.zmqServer.paradigmBuffer.getData(25)
|
||||
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
||||
time.sleep(0.005)
|
||||
continue;
|
||||
self.zmqServer.paradigmBuffer.getData(25)
|
||||
except Exception as e:
|
||||
algo_log(f"Decoder Loop Error: {e}")
|
||||
time.sleep(0.1) # Prevent CPU spin if error is persistent
|
||||
@@ -223,31 +224,32 @@ class Decoder_main(threading.Thread):
|
||||
self.zmqServer.StartDecode = False
|
||||
self.decodingSteps = 1
|
||||
self.zmqServer.paradigmBuffer.resetAllPara()
|
||||
print('启动预测')
|
||||
algo_log('启动SSVEP预测', level="DEBUG")
|
||||
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 50:
|
||||
time.sleep(0.005)
|
||||
return
|
||||
if self.zmqServer.open_Impedance: # 阻抗检测状态不解码
|
||||
return
|
||||
data = self.zmqServer.paradigmBuffer.getDataViaSSVEP(50)
|
||||
algo_log(f"SSVEP取出的:{data.shape}, data = {data[:20]}", level="DEBUG")
|
||||
data = data[:self.n_chan, :]
|
||||
if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热
|
||||
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
|
||||
self.dw.warmFilter(data) # 预热
|
||||
self.decodingSteps = 2
|
||||
print('预热数据完成。开始预测')
|
||||
algo_log('SSVEP预热数据完成。开始预测', level="DEBUG")
|
||||
return
|
||||
if self.decodingSteps == 2 and hasattr(self,'dw'): # 解码中
|
||||
choosenNum = self.dw.fbccaDWMW(data, self.referenceData, self.DW_cost_tv, self.calculateCount)
|
||||
self.calculateCount += 1
|
||||
if choosenNum != -1 and self.is_valid_signal(data):
|
||||
self.decodingSteps = 3
|
||||
print('预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount))
|
||||
algo_log('SSVEP预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount), level="DEBUG")
|
||||
self.calculateCount = 0
|
||||
if self.decodingSteps == 3: # 发送解码后的信息
|
||||
self.zmqServer.broadcast_message('result', int(choosenNum))
|
||||
self.decodingSteps = 0
|
||||
print('发送给界面完成。')
|
||||
algo_log('SSVEP发送给界面完成。', level="DEBUG")
|
||||
|
||||
def decoder_SSMVEP(self):
|
||||
'''模型训练'''
|
||||
@@ -255,34 +257,28 @@ class Decoder_main(threading.Thread):
|
||||
self.trainLabel.count(i) >= self.single_train for i in range(len(self.list_freqs))): # 模型尚未训练完成
|
||||
self.trainData = np.array(self.trainData)
|
||||
self.trainLabel = np.array(self.trainLabel)
|
||||
print(np.shape(self.trainData), (self.trainLabel))
|
||||
# 保存多个数组到文件
|
||||
# np.savez('20250520_yy.npz', array1=self.trainData, array2=self.trainLabel)
|
||||
# self.decoder = self.fbtdca.fit(self.trainData, self.trainLabel, Yf=self.Yf)
|
||||
algo_log(f"开始SSMVEP模型训练,数据形状:{np.shape(self.trainData)},标签形状:{self.trainLabel.shape}", level="DEBUG")
|
||||
if save_train_data == 1:
|
||||
now_str = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
save_path = f"{now_str}.npz"
|
||||
np.savez(save_path, array1=self.trainData, array2=self.trainLabel)
|
||||
self.decoder = self.tdca.fit(self.trainData, self.trainLabel, Yf=self.Yf)
|
||||
now = datetime.now()
|
||||
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||
print('模型训练完成', formatted_time)
|
||||
algo_log(f"SSMVEP模型训练完成,时间:{formatted_time}", level="DEBUG")
|
||||
self.load_model = True
|
||||
self.zmqServer.broadcast_message('paradigm', 1)
|
||||
|
||||
'''训练阶段采集数据'''
|
||||
if self.zmqServer.state_mode == 'train': # 训练状态
|
||||
|
||||
|
||||
if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \
|
||||
self.train_epoch[1] + self.zmqServer.event_inner_idx:
|
||||
|
||||
self.currentLabel = self.zmqServer.currentLabel
|
||||
|
||||
print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
|
||||
trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据
|
||||
|
||||
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.zmqServer.event_inner_idx])
|
||||
algo_log(f"取出的:{trainTrial.shape},event:{trainTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
|
||||
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
|
||||
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.train_epoch[1]]
|
||||
print('trial: ', self.zmqServer.event_inner_idx, self.train_epoch[0], self.train_epoch[1])
|
||||
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
|
||||
self.trainLabel, list) \
|
||||
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||
@@ -301,15 +297,14 @@ class Decoder_main(threading.Thread):
|
||||
self.zmqServer.StartDecode = False
|
||||
now = datetime.now()
|
||||
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||
print('启动预测 ', formatted_time)
|
||||
|
||||
algo_log(f"SSMVEP模型启动预测 {formatted_time}", level="DEBUG")
|
||||
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||
self.interval_epoch[1] \
|
||||
+ self.zmqServer.event_inner_idx:
|
||||
time.sleep(0.0001)
|
||||
return
|
||||
data = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 读取全部数据
|
||||
print('取出的: ', data.shape, 'event: ', data[-2, self.zmqServer.event_inner_idx])
|
||||
algo_log(f"取出的:{data.shape}, event: {data[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
|
||||
data = self.preprocess(data[:self.n_chan, :]) # 预处理
|
||||
data = data[:,
|
||||
self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||
@@ -320,16 +315,14 @@ class Decoder_main(threading.Thread):
|
||||
choosenNum, features_2 = self.decoder.predict(pad_eeg_test)
|
||||
if isinstance(choosenNum, np.ndarray):
|
||||
choosenNum = choosenNum[0]
|
||||
print('结果:', choosenNum, 'rho: ', sorted(features_2[0]),
|
||||
sorted(features_2[0])[-1] - sorted(features_2[0])[-2])
|
||||
algo_log(f"结果:{choosenNum}, rho: {sorted(features_2[0])[-1] - sorted(features_2[0])[-2]}", level="DEBUG")
|
||||
self.zmqServer.broadcast_message('result', int(choosenNum))
|
||||
print('发送给界面完成。')
|
||||
algo_log("SSMVEP发送给界面完成。", level="DEBUG")
|
||||
else: # 休息状态
|
||||
if self.zmqServer.open_Impedance == False: # 非阻抗检测状态
|
||||
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
||||
time.sleep(0.005)
|
||||
return
|
||||
self.zmqServer.paradigmBuffer.getData(25)
|
||||
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
||||
time.sleep(0.005)
|
||||
return
|
||||
self.zmqServer.paradigmBuffer.getData(25)
|
||||
|
||||
def decoder_MI(self):
|
||||
'''模型训练'''
|
||||
@@ -339,7 +332,11 @@ class Decoder_main(threading.Thread):
|
||||
self.train_started = True
|
||||
self.trainData = np.array(self.trainData)
|
||||
self.trainLabel = np.array(self.trainLabel) + 1
|
||||
# print('训练集:',np.shape(self.trainData), (self.trainLabel))
|
||||
algo_log(f"MI开始训练,训练集:{np.shape(self.trainData)},标签shape:{np.shape(self.trainLabel)}", level="DEBUG")
|
||||
if save_train_data == 1:
|
||||
now_str = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
save_path = f"{now_str}.npz"
|
||||
np.savez(save_path, array1=self.trainData, array2=self.trainLabel)
|
||||
p = mp.Process(target=onlineTrain, args=(self.mp_data_queue, self.mp_result_queue)) # 开启子进程,训练模型
|
||||
p.start()
|
||||
self.mp_data_queue.put({'data': self.trainData, 'label': self.trainLabel, 'modelPath': self.modelPath,
|
||||
@@ -350,7 +347,7 @@ class Decoder_main(threading.Thread):
|
||||
try:
|
||||
result = self.mp_result_queue.get_nowait()
|
||||
if result['status'] == 'success':
|
||||
print("模型训练完成,加载新模型")
|
||||
algo_log("MI模型训练完成,加载新模型", level="DEBUG")
|
||||
# 调用模型
|
||||
self.model = torch.load(self.modelPath, weights_only=False)
|
||||
self.model.eval()
|
||||
@@ -363,45 +360,42 @@ class Decoder_main(threading.Thread):
|
||||
self.load_model = True
|
||||
self.zmqServer.broadcast_message('paradigm', 1) # 模型调用完毕,通知上位机
|
||||
else:
|
||||
print("训练失败:", result['msg'])
|
||||
algo_log("MI训练失败: " + result['msg'], level="DEBUG")
|
||||
except Empty:
|
||||
pass # 还没完成
|
||||
except Exception as e:
|
||||
print('模型调用失败: ', e)
|
||||
algo_log("MI模型训练失败: " + str(e), level="DEBUG")
|
||||
|
||||
'''训练阶段采集数据'''
|
||||
if self.zmqServer.state_mode == 'train' and self.train_started == False: # 训练状态
|
||||
if self.zmqServer.StartTrain:
|
||||
self.currentLabel = self.zmqServer.currentLabel
|
||||
self.zmqServer.StartTrain = False
|
||||
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||
self.interval_epoch[1] \
|
||||
+ self.zmqServer.event_inner_idx:
|
||||
if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \
|
||||
self.interval_epoch[1] + self.zmqServer.event_inner_idx:
|
||||
algo_log(f"训练队列数据:{self.zmqServer.paradigmBuffer.GetDataLenCount()}", level="DEBUG")
|
||||
originalTrial = self.zmqServer.paradigmBuffer.get_MIData() # 取出MI导联数据
|
||||
algo_log(f"取出的:{originalTrial.shape},event: {originalTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
|
||||
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
|
||||
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
|
||||
algo_log(f"trial: {self.zmqServer.event_inner_idx},{self.interval_epoch[0]},{self.interval_epoch[1]}", level="DEBUG")
|
||||
if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel,
|
||||
list) \
|
||||
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||
self.trainData.append(trainTrial)
|
||||
self.trainLabel.append(self.currentLabel)
|
||||
algo_log(f"训练集:{np.shape(self.trainData)}", level="DEBUG")
|
||||
self.plotData.append(originalTrial[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]])
|
||||
self.plotLabel.append(self.currentLabel)
|
||||
else:
|
||||
time.sleep(0.0001)
|
||||
return
|
||||
print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
|
||||
originalTrial = self.zmqServer.paradigmBuffer.get_MIData() # 取出MI导联数据
|
||||
print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, self.zmqServer.event_inner_idx])
|
||||
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
|
||||
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
|
||||
print('trial: ', self.zmqServer.event_inner_idx, self.interval_epoch[0], self.interval_epoch[1])
|
||||
if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel,
|
||||
list) \
|
||||
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||
self.trainData.append(trainTrial)
|
||||
self.trainLabel.append(self.currentLabel)
|
||||
print('训练集:', np.shape(self.trainData))
|
||||
self.plotData.append(originalTrial[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]])
|
||||
self.plotLabel.append(self.currentLabel)
|
||||
|
||||
elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态
|
||||
if self.zmqServer.StartDecode:
|
||||
self.zmqServer.StartDecode = False
|
||||
now = datetime.now()
|
||||
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||
print('启动预测 ', formatted_time)
|
||||
algo_log(f"MI启动预测 {formatted_time}", level="DEBUG")
|
||||
|
||||
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||
self.interval_epoch[1] \
|
||||
@@ -409,7 +403,7 @@ class Decoder_main(threading.Thread):
|
||||
time.sleep(0.0001)
|
||||
return
|
||||
originalData = self.zmqServer.paradigmBuffer.get_MIData() # 读取全部数据
|
||||
print('取出的: ', originalData.shape, 'event: ', originalData[-2, self.zmqServer.event_inner_idx])
|
||||
algo_log(f"取出的:{originalData.shape},event: {originalData[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
|
||||
start = time.time()
|
||||
data = self.preprocess(originalData[:self.n_chan, :]) # 预处理
|
||||
data = data[:,
|
||||
@@ -426,16 +420,15 @@ class Decoder_main(threading.Thread):
|
||||
Cls = self.model(test_data)
|
||||
y_pred = torch.max(Cls, 1)[1]
|
||||
self.plotLabel.append(int(y_pred.item()))
|
||||
print('运动意图识别: ', y_pred)
|
||||
algo_log(f"MI运动意图识别: {y_pred}")
|
||||
self.zmqServer.broadcast_message('paradigm', int(y_pred.item()))
|
||||
end = time.time()
|
||||
print(f'发送给界面完成,耗时{end - start:.3f}s。')
|
||||
else: # 休息状态
|
||||
if self.zmqServer.open_Impedance == False: # 非阻抗检测状态
|
||||
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
||||
time.sleep(0.005)
|
||||
return
|
||||
self.zmqServer.paradigmBuffer.getData(25)
|
||||
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
||||
time.sleep(0.005)
|
||||
return
|
||||
self.zmqServer.paradigmBuffer.getData(25)
|
||||
|
||||
# def decoder_concentration(self):
|
||||
# if self.zmqServer.state_mode == 'predict':
|
||||
|
||||
Reference in New Issue
Block a user