Files
bci_algo/Decoder.py
2026-06-12 11:33:48 +08:00

497 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import ast
import glob
import multiprocessing as mp
import os
import sys
import threading
import time
import traceback
from collections import deque
from datetime import datetime
from queue import Empty

import numpy as np
import torch
from scipy import signal
from torch.autograd import Variable

# from Device.SunnyLinker import SunnyLinker64
from SSMVEP.algorithm.tdca import TDCA
from SSMVEP.algorithm.base import generate_cca_references
# from concentration.algorithm.calculate_focus import Calculate
# from blinkdetection.algorithm.eye_detection import blink_detection
from Zmq.zmqServer import zmqServer
from Zmq.zmqClient import zmqClient
from MI.Algorithm.conformer_2class import onlineTrain
from PubLibrary.InifileHelper import IniRead
from logs.log import algo_log
from SSVEP.dwfbcca import FbccaDw
# from Tools.plot_MI_EEG import plotMain
from Zmq.filterProcess import SlidingFilter
# Debug switch read via IniRead('system', 'save_train_data', 0) (0 presumably
# being the fallback). When 1, every completed SSMVEP/MI training set is also
# written to a timestamped .npz file in the working directory.
save_train_data = int(
    IniRead('system', 'save_train_data', 0)
)
def get_root_path():
    """
    Return the program root directory.

    Packaged build: the directory containing the executable.
    Development run: the directory containing this .py file.

    PyInstaller-style freezers set ``sys.frozen``. Nuitka deliberately does
    not; instead it injects ``__compiled__`` into the globals of every
    compiled module, so both markers are checked.
    """
    if getattr(sys, 'frozen', False):
        # Frozen build (e.g. PyInstaller): the interpreter *is* the executable.
        return os.path.dirname(sys.executable)
    if '__compiled__' in globals():
        # Nuitka build: sys.argv[0] is the launched executable. In onefile
        # mode sys.executable / __file__ point into the temporary unpack dir.
        return os.path.dirname(os.path.abspath(sys.argv[0]))
    # Development run: directory of this source file.
    return os.path.dirname(os.path.abspath(__file__))


# Sub-directory of get_root_path() that holds the per-session online models (.pth).
MODEL_FOLDER = "online_Models"
class Decoder_main(threading.Thread):
    """
    Online EEG decoding thread.

    Owns a ``zmqServer`` and a ``SlidingFilter`` worker.
    - ``zmqServer`` is started in ``__init__``. It exposes the paradigm and
      filter ring buffers plus the control flags set by the host application.
    - ``SlidingFilter`` is started once enough data has been buffered.

    ``run()`` loops until ``stop()``. Each iteration applies pending decoder
    switch requests, then dispatches to the decoder for the active
    ``decoder_class``: 'ssvep'/'pvs', 'ssmvep', or 'mi'/'ma'. Any other value
    only drains the paradigm buffer.
    """

    def __init__(self, device_info=None):
        threading.Thread.__init__(self)
        self.device_info = device_info
        self.Runing = True  # main-loop flag, cleared by stop() (public name kept as-is)
        self.decoder = None  # result of TDCA.fit (SSMVEP)
        self.decoder_class = None  # active decoder category
        self.decodingSteps = 0  # SSVEP state: 0=idle 1=warm-up 2=decoding 3=result ready
        self._filter_started = False  # SlidingFilter.start() is called at most once
        self.zmqServer = zmqServer(device_info=self.device_info)
        self.zmqServer.start()  # start the ZMQ receive thread
        if self.device_info is None:
            # The rest of the class indexes self.device_info; fall back to the
            # server's copy (which the filter set-up below already relies on).
            self.device_info = self.zmqServer.device_info
        self.sliding_filter = SlidingFilter(
            ring_buffer=self.zmqServer.filterBuffer,
            n_chan=self.zmqServer.device_info['channel_nums'],
            srate=self.zmqServer.device_info['sample_rate'],
        )
        # Filtered chunks are forwarded through the ZMQ server.
        self.sliding_filter.filter_result_callback = self.zmqServer.send_filtered_data
        # beta_psd broadcast to the host (intended cadence: once per second via
        # port 8099 - the timing itself lives in SlidingFilter).
        self.sliding_filter.set_beta_broadcast_callback(
            lambda v: self.zmqServer.broadcast_message('beta_psd', v))

    def is_valid_signal(self, data, threshold=1e5):
        """
        Artifact gate: reject an epoch whose mean per-channel variance exceeds *threshold*.

        :param data: array indexed (channels, samples); variance is taken along axis 1.
        :param threshold: upper bound for the mean channel variance.
        :return: False if the mean variance is above *threshold*, otherwise True.
        """
        energy = np.mean(np.var(data, axis=1))  # mean of the per-channel variances
        # "not >" rather than "<=" keeps a NaN energy classified as valid.
        return not energy > threshold

    @staticmethod
    def _seconds_to_samples(seconds, srate):
        """Convert times in seconds to integer sample indices, rounding (not truncating) float noise."""
        return [int(round(s * srate)) for s in seconds]

    @staticmethod
    def _now_ms():
        """Current wall-clock time as 'HH:MM:SS.mmm' for log messages."""
        return datetime.now().strftime('%H:%M:%S.%f')[:-3]

    def init_Decoder(self, decoder_class):
        '''
        Initialise the decoder for the requested paradigm.

        :param decoder_class: 'ssvep'/'pvs', 'ssmvep', or 'mi'/'ma'. Any other
            value (e.g. '' or None) leaves the main loop in drain-only mode.
        :return: None
        '''
        self.decoder_class = decoder_class
        if decoder_class in ('ssvep', 'pvs'):
            self.n_chan = 8
            DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue'))
            self.ListFreq = self.zmqServer.targetFreqs
            self.num_target = len(self.ListFreq)
            if self.num_target == 0:
                # No targets: drop any decoder from a previous session so
                # decoder_SSVEP cannot decode against stale references.
                self.dw = None
                return
            # Second-generation (dynamic-window FBCCA) decoder.
            self.dw = FbccaDw(self.device_info['sample_rate'], self.num_target, self.n_chan, 5, 5,
                              0.2, [2.0, 0.1], [8, 7], 50, DW_cost_method)
            self.dw.filterFrequenceBank()  # filter-bank sub-bands
            self.dw.setNotchFilterPara()
            self.calculateCount = 0
            self.referenceData = self.dw.reference(
                self.ListFreq, int(50 * 0.2 * self.device_info['sample_rate']), 5)
            self.dw.filterInit()
            self.dw.onlineInit()  # reset at the first second of flicker (online acquisition)
        elif decoder_class == 'ssmvep':
            self.zmqServer.interval_init(decoder_class)
            self.n_chan = 8
            # Epoch window relative to the event, in seconds, e.g. [0.2, 2.2].
            self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
            # Decoding window length in seconds, rounded to 6 decimals to absorb float noise.
            self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6)
            self.single_train = 10  # training epochs required per class
            self.num_target = 2  # number of targets
            self.list_freqs = np.array([8, 9])  # stimulus frequencies (Hz)
            self.list_phase = np.array([0, 0])  # stimulus phases
            self.tdca = TDCA(padding_len=5, n_components=1)
            self.Yf = generate_cca_references(self.list_freqs, srate=self.device_info['sample_rate'],
                                              T=self.sample_length, phases=self.list_phase, n_harmonics=5)
            self.parameter_init(5, 45)
        elif decoder_class in ('mi', 'ma'):
            self.zmqServer.interval_init(decoder_class)
            self.n_chan = 21
            # Epoch window relative to the event, in seconds, e.g. [0.5, 4.5].
            self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
            # Decoding window length in seconds, rounded to 6 decimals to absorb float noise.
            self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6)
            self.single_train = 40  # training epochs required per class
            self.num_target = 2  # number of targets
            self.parameter_init(8, 30)
        # The former 'concentration' and 'blink' decoders were disabled
        # (commented out); recover them from version control if needed.

    def parameter_init(self, bandPass_low, bandPass_high):
        '''
        Shared set-up for the trained paradigms (SSMVEP and MI).

        - Converts ``self.interval_epoch`` from seconds to sample indices.
        - Resets the training and report buffers.
        - Designs the notch and band-pass filters used by preprocess().
        - Removes old ``.pth`` files and prepares ``self.modelPath`` plus the
          queues shared with the training subprocess.

        :param bandPass_low: band-pass lower edge in Hz.
        :param bandPass_high: band-pass upper edge in Hz.
        '''
        srate = self.device_info['sample_rate']
        nyquist = srate / 2
        # Epoch window in samples relative to the event marker.
        self.interval_epoch = self._seconds_to_samples(self.interval_epoch, srate)
        # Training epochs extend 0.1 s past the decoding window; SSMVEP test
        # epochs are zero-padded to this same length before prediction.
        self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * srate)]
        self.trainData = []  # training epochs
        self.trainLabel = []  # training labels
        self.plotData = []  # raw epochs kept for the analysis report
        self.plotLabel = []  # labels / predictions for the analysis report
        self.currentLabel = -1  # training label currently shown by the stimulus UI
        self.train_started = False  # MI: training subprocess launched
        self.load_model = False  # model trained/loaded and ready to predict
        # 50 Hz mains notch, quality factor 30, at the device sample rate.
        self.b_notch, self.a_notch = signal.iirnotch(50 / nyquist, 30)
        # 65-tap FIR band-pass [bandPass_low, bandPass_high] Hz.
        self.b_design = signal.firwin(65, [bandPass_low / nyquist, bandPass_high / nyquist], pass_zero=False)
        model_dir = os.path.join(get_root_path(), MODEL_FOLDER)
        os.makedirs(model_dir, exist_ok=True)  # modelPath (handed to the trainer) lives here
        # Only the model of the current session is kept.
        for old_pth in glob.glob(os.path.join(model_dir, '*.pth')):
            os.remove(old_pth)
        fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
        self.modelPath = os.path.join(model_dir, fileName + '.pth')
        self.mp_data_queue = mp.Queue()
        self.mp_result_queue = mp.Queue()

    def preprocess(self, signal_data):
        '''
        Demean every channel, then apply the 50 Hz notch and the FIR band-pass.

        :param signal_data: array with samples on the last axis.
        :return: filtered array of the same shape.
        '''
        signal_data = signal_data - np.mean(signal_data, axis=-1, keepdims=True)
        signal_data = signal.lfilter(self.b_notch, self.a_notch, signal_data, axis=-1)  # mains notch
        return signal.lfilter(self.b_design, 1, signal_data, axis=-1)  # band-pass

    def run(self):
        '''Decoder main loop; runs until stop() clears ``self.Runing``.'''
        while self.Runing:
            try:
                self._start_filter_when_ready()
                if self.zmqServer.decoder_switch or self.zmqServer.changeTarget:
                    self._switch_decoder()
                if self.zmqServer.state_mode == 'sync':
                    # Sync requests just return to 'rest' (the zmqClient sync
                    # broadcast is currently disabled).
                    self.zmqServer.state_mode = 'rest'
                if self.decoder_class in ('ssvep', 'pvs'):
                    self.decoder_SSVEP()
                elif self.decoder_class == 'ssmvep':
                    self.decoder_SSMVEP()
                elif self.decoder_class in ('mi', 'ma'):  # 'ma' is initialised like 'mi'
                    self.decoder_MI()
                else:
                    self._drain_paradigm_buffer()
            except Exception as e:
                algo_log(f"Decoder Loop Error: {e}\n{traceback.format_exc()}")
                time.sleep(0.1)  # avoid a hot loop if the error persists

    def _start_filter_when_ready(self):
        """Start the SlidingFilter worker once the filter buffer holds more than 5 s of data."""
        if self._filter_started:
            return
        if self.zmqServer.filterBuffer.GetDataLenCount() > self.device_info['sample_rate'] * 5:
            algo_log("启动滤波线程", level="DEBUG")
            # Set the flag first: a thread may be start()ed only once, so a
            # worker that has already exited must never be started again.
            self._filter_started = True
            self.sliding_filter.start()

    def _switch_decoder(self):
        """Apply a decoder-class / target change from the host: clear old state, then re-initialise."""
        algo_log(f"Decoder_class Switch Detected: {self.zmqServer.decoder_class}", level="DEBUG")
        self.zmqServer.decoder_switch = False
        self.zmqServer.changeTarget = False
        try:
            self.reset_state()
            self.init_Decoder(self.zmqServer.decoder_class)
        except Exception as e:
            # A half-initialised decoder would fail on every iteration; stay in
            # drain-only mode until the next switch request.
            algo_log(f"Decoder init failed: {e}\n{traceback.format_exc()}")
            self.decoder_class = None

    def _drain_paradigm_buffer(self, n_samples=25):
        """Idle handling: discard paradigm-buffer data in chunks of *n_samples* so it does not pile up."""
        if self.zmqServer.paradigmBuffer.GetDataLenCount() < n_samples:
            time.sleep(0.005)
            return
        self.zmqServer.paradigmBuffer.getData(n_samples)

    def _training_set_complete(self):
        """True while labels are still being collected (a list) and both classes 1 and 2 have single_train epochs."""
        return isinstance(self.trainLabel, list) and all(
            self.trainLabel.count(i) >= self.single_train for i in (1, 2))

    def _save_training_set(self):
        """If save_train_data == 1, dump the training set to '<YYYYmmdd_HHMMSS>.npz' in the working directory."""
        if save_train_data == 1:
            save_path = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}.npz"
            np.savez(save_path, array1=self.trainData, array2=self.trainLabel)

    def decoder_SSVEP(self):
        '''
        One step of the SSVEP/PVS dynamic-window FBCCA state machine (50-sample chunks).

        The host's StartDecode flag resets the buffer and sets decodingSteps to 1.
        - 1: the first chunk warms up the filters, then switches to 2.
        - 2: chunks are fed to fbccaDWMW until it returns a target (!= -1) on a
          chunk that passes is_valid_signal.
        - 3: the result is broadcast, then the state returns to idle (0).
        '''
        if self.zmqServer.StartDecode:
            self.zmqServer.StartDecode = False
            self.decodingSteps = 1
            self.zmqServer.paradigmBuffer.resetAllPara()
            algo_log('启动SSVEP预测', level="DEBUG")
        if self.zmqServer.paradigmBuffer.GetDataLenCount() < 50:
            time.sleep(0.005)
            return
        if self.zmqServer.open_Impedance:
            # No decoding during impedance checks; sleep so the loop does not
            # spin at full CPU while the buffer stays above 50 samples.
            time.sleep(0.005)
            return
        data = self.zmqServer.paradigmBuffer.getDataViaSSVEP(50)[:self.n_chan, :]
        dw = getattr(self, 'dw', None)
        if dw is None:  # no targets configured; the chunk is simply consumed
            return
        if self.decodingSteps == 1:  # warm-up
            dw.onlineInit()  # reset at the first second of flicker (online acquisition)
            dw.warmFilter(data)
            self.decodingSteps = 2
            algo_log('SSVEP预热数据完成。开始预测', level="DEBUG")
        elif self.decodingSteps == 2:  # decoding
            choosenNum = dw.fbccaDWMW(data, self.referenceData, self.DW_cost_tv, self.calculateCount)
            self.calculateCount += 1
            if choosenNum != -1 and self.is_valid_signal(data):
                self.decodingSteps = 3  # result ready
                algo_log('SSVEP预测结果' + str(choosenNum) + ',计算次数:' + str(self.calculateCount), level="DEBUG")
                self.calculateCount = 0
                self.zmqServer.broadcast_message('result', int(choosenNum))
                self.decodingSteps = 0
                algo_log('SSVEP发送给界面完成。', level="DEBUG")

    def decoder_SSMVEP(self):
        '''
        One step of the SSMVEP (TDCA) workflow.

        - Fits the model once both classes have ``single_train`` epochs.
        - 'train' state: collects epochs.
        - 'predict' state: classifies each finished epoch and broadcasts the result.
        - Any other state: drains the buffer.
        '''
        if not self.load_model and self._training_set_complete():
            self._fit_ssmvep()
        state = self.zmqServer.state_mode
        if state == 'train':
            self._collect_ssmvep_trial()
        elif state == 'predict':
            self._predict_ssmvep()
        else:  # rest
            self._drain_paradigm_buffer()

    def _fit_ssmvep(self):
        """Freeze the collected SSMVEP training set into arrays, fit TDCA and tell the host the model is ready."""
        self.trainData = np.array(self.trainData)
        self.trainLabel = np.array(self.trainLabel)
        algo_log(f"开始SSMVEP模型训练数据形状{np.shape(self.trainData)},标签形状:{self.trainLabel.shape}", level="DEBUG")
        self._save_training_set()
        self.decoder = self.tdca.fit(self.trainData, self.trainLabel, Yf=self.Yf)
        algo_log(f"SSMVEP模型训练完成时间{self._now_ms()}", level="DEBUG")
        self.load_model = True
        self.zmqServer.broadcast_message('paradigm', 1)

    def _collect_ssmvep_trial(self):
        """Append one finished, preprocessed training epoch (train_epoch window) under the current label."""
        server = self.zmqServer
        if not (server.epoch_finished and
                server.paradigmBuffer.GetDataLenCount() >= self.train_epoch[1] + server.event_inner_idx):
            time.sleep(0.0001)
            return
        self.currentLabel = server.currentLabel
        trainTrial = server.paradigmBuffer.get_SSMVEPData()  # all buffered data
        event = server.event_inner_idx  # event marker column within trainTrial
        algo_log(f"取出的:{trainTrial.shape}event{trainTrial[-2, event]}", level="DEBUG")
        trainTrial = self.preprocess(trainTrial[:self.n_chan, :])
        trainTrial = trainTrial[:, event + self.train_epoch[0]:event + self.train_epoch[1]]
        if (trainTrial.shape[1] == self.train_epoch[1] - self.train_epoch[0]
                and isinstance(self.trainLabel, list)
                and self.trainLabel.count(self.currentLabel) < self.single_train):
            self.trainData.append(trainTrial)
            self.trainLabel.append(self.currentLabel)

    def _predict_ssmvep(self):
        """Classify the latest finished epoch with the fitted TDCA model and broadcast the predicted target."""
        if not self.load_model:  # model not trained yet
            time.sleep(0.01)
            return
        server = self.zmqServer
        if server.StartDecode:
            server.StartDecode = False
            algo_log(f"SSMVEP模型启动预测 {self._now_ms()}", level="DEBUG")
        if not server.epoch_finished or \
                server.paradigmBuffer.GetDataLenCount() < self.interval_epoch[1] + server.event_inner_idx:
            time.sleep(0.0001)
            return
        data = server.paradigmBuffer.get_SSMVEPData()  # all buffered data
        event = server.event_inner_idx
        algo_log(f"取出的:{data.shape}, event: {data[-2, event]}", level="DEBUG")
        data = self.preprocess(data[:self.n_chan, :])
        data = data[:, event + self.interval_epoch[0]:event + self.interval_epoch[1]]
        # Zero-pad to the training-epoch length the model was fitted on.
        pad_eeg_test = np.zeros((data.shape[0], self.train_epoch[1] - self.train_epoch[0]))
        pad_eeg_test[:, :data.shape[1]] = data
        choosenNum, features_2 = self.decoder.predict(pad_eeg_test)
        if isinstance(choosenNum, np.ndarray):
            choosenNum = choosenNum[0]
        algo_log(f"结果:{choosenNum}, rho: {sorted(features_2[0])[-1] - sorted(features_2[0])[-2]}", level="DEBUG")
        server.broadcast_message('result', int(choosenNum))
        algo_log("SSMVEP发送给界面完成。", level="DEBUG")

    def decoder_MI(self):
        '''
        One step of the MI workflow (also used for 'ma').

        - Launches training in a subprocess once both classes have
          ``single_train`` epochs.
        - Loads the model when the subprocess reports success.
        - 'train' state: collects epochs.
        - 'predict' state: classifies epochs.
        - Otherwise: drains the buffer.
        '''
        if not self.train_started and self._training_set_complete():
            self._start_mi_training()
        if not self.load_model and self.train_started:
            self._poll_mi_training()
        state = self.zmqServer.state_mode
        if state == 'train' and not self.train_started:
            self._collect_mi_trial()
        elif state == 'predict' and self.load_model:
            self._predict_mi()
        else:  # rest
            self._drain_paradigm_buffer()

    def _start_mi_training(self):
        """Notify the host, freeze the training set and hand it to an onlineTrain subprocess."""
        self.zmqServer.broadcast_message('paradigm', 2)  # training set complete: notify the host
        self.train_started = True
        self.trainData = np.array(self.trainData)
        self.trainLabel = np.array(self.trainLabel)
        algo_log(f"MI开始训练训练集{np.shape(self.trainData)}标签shape{np.shape(self.trainLabel)}", level="DEBUG")
        self._save_training_set()
        self.train_process = mp.Process(target=onlineTrain, args=(self.mp_data_queue, self.mp_result_queue))
        self.train_process.start()
        self.mp_data_queue.put({'data': self.trainData, 'label': self.trainLabel,
                                'modelPath': self.modelPath, 'n_chan': self.n_chan})

    def _model_input(self, array):
        """Convert a numpy array to a float32 tensor on self.model's device (CUDA if it has no parameters)."""
        param = next(self.model.parameters(), None)
        device = param.device if param is not None else torch.device('cuda')
        return torch.from_numpy(array).to(device=device, dtype=torch.float32)

    def _poll_mi_training(self):
        """Non-blocking check for the training result; on success load, warm up and announce the model."""
        try:
            result = self.mp_result_queue.get_nowait()
        except Empty:
            return  # still training
        try:
            if result['status'] == 'success':
                algo_log("MI模型训练完成加载新模型", level="DEBUG")
                # Full-object load (weights_only=False) of the file this
                # session's own trainer wrote to self.modelPath.
                self.model = torch.load(self.modelPath, weights_only=False)
                self.model.eval()
                # Warm-up pass with an input shaped like a real epoch.
                epoch_len = self.interval_epoch[1] - self.interval_epoch[0]
                warmup_data = np.random.uniform(-1, 1, (1, 1, self.n_chan, epoch_len))
                with torch.no_grad():
                    _ = self.model(self._model_input(warmup_data))
                self.load_model = True
                self.zmqServer.broadcast_message('paradigm', 1)  # model ready: notify the host
            else:
                algo_log("MI训练失败: " + result['msg'], level="DEBUG")
        except Exception as e:
            algo_log("MI模型训练失败: " + str(e), level="DEBUG")

    def _collect_mi_trial(self):
        """Append one finished, preprocessed MI training epoch (plus its raw copy for the report)."""
        server = self.zmqServer
        # NOTE(review): readiness uses zmqServer.train_epoch while slicing uses
        # self.interval_epoch (SSMVEP uses self.train_epoch) - confirm intended.
        if not (server.epoch_finished and
                server.paradigmBuffer.GetDataLenCount() >= server.train_epoch[1] + server.event_inner_idx):
            time.sleep(0.0001)
            return
        self.currentLabel = server.currentLabel  # sync the current label
        algo_log(f"训练队列数据:{server.paradigmBuffer.GetDataLenCount()}", level="DEBUG")
        originalTrial = server.paradigmBuffer.get_MIData()  # MI channel data
        event = server.event_inner_idx
        algo_log(f"取出的:{originalTrial.shape},event: {originalTrial[-2, event]}", level="DEBUG")
        start, stop = event + self.interval_epoch[0], event + self.interval_epoch[1]
        trainTrial = self.preprocess(originalTrial[:self.n_chan, :])[:, start:stop]
        if (trainTrial.shape[1] == self.interval_epoch[1] - self.interval_epoch[0]
                and isinstance(self.trainLabel, list)
                and self.trainLabel.count(self.currentLabel) < self.single_train):
            self.trainData.append(trainTrial)
            self.trainLabel.append(self.currentLabel)
            algo_log(f"训练集:{np.shape(self.trainData)}", level="DEBUG")
            self.plotData.append(originalTrial[:self.n_chan, start:stop])
            self.plotLabel.append(self.currentLabel)

    def _predict_mi(self):
        """Classify the latest finished MI epoch and broadcast the predicted class."""
        server = self.zmqServer
        if server.StartDecode:
            server.StartDecode = False
            algo_log(f"MI启动预测 {self._now_ms()}", level="DEBUG")
        if not server.epoch_finished or \
                server.paradigmBuffer.GetDataLenCount() < self.interval_epoch[1] + server.event_inner_idx:
            time.sleep(0.0001)
            return
        originalData = server.paradigmBuffer.get_MIData()  # all buffered data
        event = server.event_inner_idx
        algo_log(f"取出的:{originalData.shape},event: {originalData[-2, event]}", level="DEBUG")
        start = time.time()
        lo, hi = event + self.interval_epoch[0], event + self.interval_epoch[1]
        data = self.preprocess(originalData[:self.n_chan, :])[:, lo:hi]
        self.plotData.append(originalData[:self.n_chan, lo:hi])
        test_data = self._model_input(data[np.newaxis, np.newaxis, :, :])
        with torch.no_grad():
            Cls = self.model(test_data)
        y_pred = torch.max(Cls, 1)[1]
        self.plotLabel.append(int(y_pred.item()))
        algo_log(f"MI运动意图识别: {y_pred}")
        self.zmqServer.broadcast_message('result', int(y_pred.item()))
        end = time.time()
        algo_log(f'MI发送给界面完成,耗时{end - start:.3f}s。')

    def stop(self):
        '''
        Stop the main loop, then shut down the ZMQ server and the sliding filter.
        @return: None
        '''
        self.Runing = False  # first, so the loop stops touching components being torn down
        self.zmqServer.stop()
        self.sliding_filter.stop()

    @staticmethod
    def _drain_mp_queue(queue):
        """Best-effort removal of everything currently retrievable from a multiprocessing queue."""
        while True:
            try:
                queue.get_nowait()
            except Empty:
                return

    def reset_state(self):
        """Clear decoder state and cached data (called before every decoder switch)."""
        self.zmqServer.reset_state()  # device-level buffers
        # Decoding state
        self.decodingSteps = 0
        self.calculateCount = 0
        # Training / report data
        self.plotData = []
        self.plotLabel = []
        self.trainData = []
        self.trainLabel = []
        self.currentLabel = -1
        self.train_started = False
        self.load_model = False
        # Drain the inter-process queues so stale data cannot leak into the
        # next decoder session. empty() is unreliable for mp.Queue, so stop on
        # the first Empty instead.
        for name in ('mp_data_queue', 'mp_result_queue'):
            if hasattr(self, name):
                self._drain_mp_queue(getattr(self, name))